1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3#[cfg(feature = "framework")]
4use std::sync::OnceLock;
5
6use futures::channel::mpsc::UnboundedReceiver as Receiver;
7use futures::StreamExt;
8use tokio::sync::{Mutex, RwLock};
9use tokio::time::{sleep, timeout, Duration, Instant};
10use tracing::{debug, info, instrument, warn};
11use typemap_rev::TypeMap;
12
13#[cfg(feature = "voice")]
14use super::VoiceGatewayManager;
15use super::{
16 ShardId,
17 ShardManager,
18 ShardMessenger,
19 ShardQueuerMessage,
20 ShardRunner,
21 ShardRunnerInfo,
22 ShardRunnerOptions,
23};
24#[cfg(feature = "cache")]
25use crate::cache::Cache;
26use crate::client::{EventHandler, RawEventHandler};
27#[cfg(feature = "framework")]
28use crate::framework::Framework;
29use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
30use crate::http::Http;
31use crate::internal::prelude::*;
32use crate::internal::tokio::spawn_named;
33use crate::model::gateway::{GatewayIntents, ShardInfo};
34
/// Number of seconds the queuer leaves between two consecutive shard start attempts.
///
/// The same value is used as the idle timeout of the queuer's receive loop, so a shard waiting
/// in the retry queue is picked up once no message has arrived for this long.
const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;
36
37pub struct ShardQueuer {
42 pub data: Arc<RwLock<TypeMap>>,
46 pub event_handlers: Vec<Arc<dyn EventHandler>>,
50 pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
54 #[cfg(feature = "framework")]
56 pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
57 pub last_start: Option<Instant>,
61 pub manager: Arc<ShardManager>,
63 pub queue: VecDeque<ShardInfo>,
67 pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
69 pub rx: Receiver<ShardQueuerMessage>,
71 #[cfg(feature = "voice")]
73 pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
74 pub ws_url: Arc<Mutex<String>>,
76 #[cfg(feature = "cache")]
77 pub cache: Arc<Cache>,
78 pub http: Arc<Http>,
79 pub intents: GatewayIntents,
80 pub presence: Option<PresenceData>,
81}
82
83impl ShardQueuer {
84 #[instrument(skip(self))]
101 pub async fn run(&mut self) {
102 const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
106
107 loop {
108 match timeout(TIMEOUT, self.rx.next()).await {
109 Ok(Some(ShardQueuerMessage::Shutdown)) => {
110 debug!("[Shard Queuer] Received to shutdown.");
111 self.shutdown_runners().await;
112
113 break;
114 },
115 Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
116 debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
117 self.shutdown(shard, code).await;
118 },
119 Ok(Some(ShardQueuerMessage::Start(id, total))) => {
120 debug!("[Shard Queuer] Received to start shard {} of {}.", id.0, total.0);
121 self.checked_start(id, total.0).await;
122 },
123 Ok(None) => break,
124 Err(_) => {
125 if let Some(shard) = self.queue.pop_front() {
126 self.checked_start(shard.id, shard.total).await;
127 }
128 },
129 }
130 }
131 }
132
133 #[instrument(skip(self))]
134 async fn check_last_start(&mut self) {
135 let Some(instant) = self.last_start else { return };
136
137 let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
139 let elapsed = instant.elapsed();
140
141 if let Some(to_sleep) = duration.checked_sub(elapsed) {
142 sleep(to_sleep).await;
143 }
144 }
145
146 #[instrument(skip(self))]
147 async fn checked_start(&mut self, id: ShardId, total: u32) {
148 debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
149 self.check_last_start().await;
150
151 if let Err(why) = self.start(id, total).await {
152 warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why);
153 info!("[Shard Queuer] Re-queueing start of shard {}", id);
154
155 self.queue.push_back(ShardInfo::new(id, total));
156 }
157
158 self.last_start = Some(Instant::now());
159 }
160
161 #[instrument(skip(self))]
162 async fn start(&mut self, id: ShardId, total: u32) -> Result<()> {
163 let shard_info = ShardInfo::new(id, total);
164
165 let mut shard = Shard::new(
166 Arc::clone(&self.ws_url),
167 self.http.token(),
168 shard_info,
169 self.intents,
170 self.presence.clone(),
171 )
172 .await?;
173
174 let cloned_http = Arc::clone(&self.http);
175 shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));
176
177 let mut runner = ShardRunner::new(ShardRunnerOptions {
178 data: Arc::clone(&self.data),
179 event_handlers: self.event_handlers.clone(),
180 raw_event_handlers: self.raw_event_handlers.clone(),
181 #[cfg(feature = "framework")]
182 framework: self.framework.get().cloned(),
183 manager: Arc::clone(&self.manager),
184 #[cfg(feature = "voice")]
185 voice_manager: self.voice_manager.clone(),
186 shard,
187 #[cfg(feature = "cache")]
188 cache: Arc::clone(&self.cache),
189 http: Arc::clone(&self.http),
190 });
191
192 let runner_info = ShardRunnerInfo {
193 latency: None,
194 runner_tx: ShardMessenger::new(&runner),
195 stage: ConnectionStage::Disconnected,
196 };
197
198 spawn_named("shard_queuer::stop", async move {
199 drop(Box::pin(runner.run()).await);
200 debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
201 });
202
203 self.runners.lock().await.insert(id, runner_info);
204
205 Ok(())
206 }
207
208 #[instrument(skip(self))]
209 async fn shutdown_runners(&mut self) {
210 let keys = {
211 let runners = self.runners.lock().await;
212
213 if runners.is_empty() {
214 return;
215 }
216
217 runners.keys().copied().collect::<Vec<_>>()
218 };
219
220 info!("Shutting down all shards");
221
222 for shard_id in keys {
223 self.shutdown(shard_id, 1000).await;
224 }
225 }
226
227 #[instrument(skip(self))]
233 pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
234 info!("Shutting down shard {}", shard_id);
235
236 if let Some(runner) = self.runners.lock().await.get(&shard_id) {
237 let msg = ShardRunnerMessage::Shutdown(shard_id, code);
238
239 if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) {
240 warn!(
241 "Failed to cleanly shutdown shard {} when sending message to shard runner: {:?}",
242 shard_id,
243 why,
244 );
245 }
246 }
247 }
248}