1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3#[cfg(feature = "framework")]
4use std::sync::OnceLock;
5
6use futures::channel::mpsc::UnboundedReceiver as Receiver;
7use futures::StreamExt;
8use tokio::sync::{Mutex, RwLock};
9use tokio::time::{sleep, timeout, Duration, Instant};
10use tracing::{debug, info, instrument, warn};
11use typemap_rev::TypeMap;
12
13#[cfg(feature = "voice")]
14use super::VoiceGatewayManager;
15use super::{
16 ShardId,
17 ShardManager,
18 ShardMessenger,
19 ShardQueuerMessage,
20 ShardRunner,
21 ShardRunnerInfo,
22 ShardRunnerOptions,
23};
24#[cfg(feature = "cache")]
25use crate::cache::Cache;
26use crate::client::{EventHandler, RawEventHandler};
27#[cfg(feature = "framework")]
28use crate::framework::Framework;
29use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
30use crate::http::Http;
31use crate::internal::prelude::*;
32use crate::internal::tokio::spawn_named;
33use crate::model::gateway::{GatewayIntents, ShardInfo};
34
/// Minimum number of seconds between two consecutive shard start attempts.
///
/// Also used by `ShardQueuer::run` as the idle timeout after which the next
/// queued shard (if any) is started.
const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;
36
/// Starts shards one at a time, spacing consecutive boots apart, and forwards
/// shutdown requests to the shard runners it has started.
///
/// Instructions arrive as [`ShardQueuerMessage`]s on `rx`; see `ShardQueuer::run`.
pub struct ShardQueuer {
    /// Shared data map, cloned (by `Arc`) into every [`ShardRunner`] this queuer starts.
    pub data: Arc<RwLock<TypeMap>>,
    /// Event handlers cloned into every [`ShardRunner`] this queuer starts.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers cloned into every [`ShardRunner`] this queuer starts.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Framework slot; whatever is set at the moment a runner is created is
    /// handed to that runner (an unset slot yields no framework).
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    /// When the most recent shard start was attempted (successful or not).
    /// Used to keep at least `WAIT_BETWEEN_BOOTS_IN_SECONDS` between attempts.
    pub last_start: Option<Instant>,
    /// Manager handle passed to every [`ShardRunner`] this queuer starts.
    pub manager: Arc<ShardManager>,
    /// Shards waiting to be started, e.g. re-queued after a failed start.
    /// One entry is popped each time `rx` stays idle for the boot interval.
    pub queue: VecDeque<ShardInfo>,
    /// Bookkeeping for every runner started by this queuer, keyed by shard id.
    /// Shared (via `Arc<Mutex<..>>`) with other owners of the same map.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// Receiving end of the channel carrying instructions for this queuer.
    pub rx: Receiver<ShardQueuerMessage>,
    /// Optional voice gateway manager handed to every runner.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// Gateway URL shared with every [`Shard`] this queuer creates.
    pub ws_url: Arc<Mutex<String>>,
    /// Cache handle passed to every runner.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client; supplies the token for new shards and is passed to runners.
    pub http: Arc<Http>,
    /// Gateway intents used when creating each shard.
    pub intents: GatewayIntents,
    /// Initial presence used when creating each shard.
    pub presence: Option<PresenceData>,
}
82
83impl ShardQueuer {
84 #[instrument(skip(self))]
101 pub async fn run(&mut self) {
102 const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
106
107 loop {
108 match timeout(TIMEOUT, self.rx.next()).await {
109 Ok(Some(ShardQueuerMessage::Shutdown)) => {
110 debug!("[Shard Queuer] Received to shutdown.");
111 self.shutdown_runners().await;
112
113 break;
114 },
115 Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
116 debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
117 self.shutdown(shard, code).await;
118 },
119 Ok(Some(ShardQueuerMessage::Start(id, total))) => {
120 debug!("[Shard Queuer] Received to start shard {} of {}.", id.0, total.0);
121 self.checked_start(id, total.0).await;
122 },
123 Ok(None) => break,
124 Err(_) => {
125 if let Some(shard) = self.queue.pop_front() {
126 self.checked_start(shard.id, shard.total).await;
127 }
128 },
129 }
130 }
131 }
132
133 #[instrument(skip(self))]
134 async fn check_last_start(&mut self) {
135 let Some(instant) = self.last_start else { return };
136
137 let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
139 let elapsed = instant.elapsed();
140
141 if elapsed >= duration {
142 return;
143 }
144
145 let to_sleep = duration - elapsed;
146
147 sleep(to_sleep).await;
148 }
149
150 #[instrument(skip(self))]
151 async fn checked_start(&mut self, id: ShardId, total: u32) {
152 debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
153 self.check_last_start().await;
154
155 if let Err(why) = self.start(id, total).await {
156 warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why);
157 info!("[Shard Queuer] Re-queueing start of shard {}", id);
158
159 self.queue.push_back(ShardInfo::new(id, total));
160 }
161
162 self.last_start = Some(Instant::now());
163 }
164
165 #[instrument(skip(self))]
166 async fn start(&mut self, id: ShardId, total: u32) -> Result<()> {
167 let shard_info = ShardInfo::new(id, total);
168
169 let mut shard = Shard::new(
170 Arc::clone(&self.ws_url),
171 self.http.token(),
172 shard_info,
173 self.intents,
174 self.presence.clone(),
175 )
176 .await?;
177
178 let cloned_http = Arc::clone(&self.http);
179 shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));
180
181 let mut runner = ShardRunner::new(ShardRunnerOptions {
182 data: Arc::clone(&self.data),
183 event_handlers: self.event_handlers.clone(),
184 raw_event_handlers: self.raw_event_handlers.clone(),
185 #[cfg(feature = "framework")]
186 framework: self.framework.get().cloned(),
187 manager: Arc::clone(&self.manager),
188 #[cfg(feature = "voice")]
189 voice_manager: self.voice_manager.clone(),
190 shard,
191 #[cfg(feature = "cache")]
192 cache: Arc::clone(&self.cache),
193 http: Arc::clone(&self.http),
194 });
195
196 let runner_info = ShardRunnerInfo {
197 latency: None,
198 runner_tx: ShardMessenger::new(&runner),
199 stage: ConnectionStage::Disconnected,
200 };
201
202 spawn_named("shard_queuer::stop", async move {
203 drop(runner.run().await);
204 debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
205 });
206
207 self.runners.lock().await.insert(id, runner_info);
208
209 Ok(())
210 }
211
212 #[instrument(skip(self))]
213 async fn shutdown_runners(&mut self) {
214 let keys = {
215 let runners = self.runners.lock().await;
216
217 if runners.is_empty() {
218 return;
219 }
220
221 runners.keys().copied().collect::<Vec<_>>()
222 };
223
224 info!("Shutting down all shards");
225
226 for shard_id in keys {
227 self.shutdown(shard_id, 1000).await;
228 }
229 }
230
231 #[instrument(skip(self))]
237 pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
238 info!("Shutting down shard {}", shard_id);
239
240 if let Some(runner) = self.runners.lock().await.get(&shard_id) {
241 let msg = ShardRunnerMessage::Shutdown(shard_id, code);
242
243 if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) {
244 warn!(
245 "Failed to cleanly shutdown shard {} when sending message to shard runner: {:?}",
246 shard_id,
247 why,
248 );
249 }
250 }
251 }
252}