use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use std::time::Duration;

use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use futures::{SinkExt, StreamExt};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{info, instrument, warn};
use typemap_rev::TypeMap;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;

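/// A manager for handling the status of shards: starting them, restarting them, and shutting
/// them down when required.
///
/// Communicates with a background [`ShardQueuer`] task over an unbounded channel; the queuer is
/// what actually starts and stops the individual shard runners.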
#[derive(Debug)]
pub struct ShardManager {
    /// Sender used to report the manager's final result back over the receiver returned by
    /// [`Self::new`].
    return_value_tx: Mutex<Sender<Result<(), GatewayError>>>,
    /// The shard runners currently managed, keyed by shard ID.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// The index of the first shard to initialize.
    shard_index: AtomicU32,
    /// The number of shards to initialize.
    shard_init: AtomicU32,
    /// The total number of shards in use.
    shard_total: AtomicU32,
    /// Channel to the background [`ShardQueuer`] task for start and shutdown messages.
    shard_queuer: Sender<ShardQueuerMessage>,
    /// Receiver on which shard runners confirm that they have finished shutting down.
    shard_shutdown: Mutex<Receiver<ShardId>>,
    /// Sending half of the shutdown-confirmation channel, used by [`Self::shutdown_finished`].
    shard_shutdown_send: Sender<ShardId>,
    /// The gateway intents the shards connect with.
    gateway_intents: GatewayIntents,
}

impl ShardManager {
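    /// Creates a new shard manager and spawns the background [`ShardQueuer`] task that starts
    /// the shards.
    ///
    /// Returns the manager together with a receiver for the manager's return value: `Ok(())` is
    /// sent after [`Self::shutdown_all`], and errors can be reported via
    /// [`Self::return_with_value`].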
    #[must_use]
    pub fn new(opt: ShardManagerOptions) -> (Arc<Self>, Receiver<Result<(), GatewayError>>) {
        let (return_value_tx, return_value_rx) = mpsc::unbounded();
        let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded();

        let runners = Arc::new(Mutex::new(HashMap::new()));
        let (shutdown_send, shutdown_recv) = mpsc::unbounded();

        let manager = Arc::new(Self {
            return_value_tx: Mutex::new(return_value_tx),
            shard_index: AtomicU32::new(opt.shard_index),
            shard_init: AtomicU32::new(opt.shard_init),
            shard_queuer: shard_queue_tx,
            shard_total: AtomicU32::new(opt.shard_total),
            shard_shutdown: Mutex::new(shutdown_recv),
            shard_shutdown_send: shutdown_send,
            runners: Arc::clone(&runners),
            gateway_intents: opt.intents,
        });

        let mut shard_queuer = ShardQueuer {
            data: opt.data,
            event_handlers: opt.event_handlers,
            raw_event_handlers: opt.raw_event_handlers,
            #[cfg(feature = "framework")]
            framework: opt.framework,
            last_start: None,
            manager: Arc::clone(&manager),
            queue: VecDeque::new(),
            runners,
            rx: shard_queue_rx,
            #[cfg(feature = "voice")]
            voice_manager: opt.voice_manager,
            ws_url: opt.ws_url,
            #[cfg(feature = "cache")]
            cache: opt.cache,
            http: opt.http,
            intents: opt.intents,
            presence: opt.presence,
        };

        spawn_named("shard_queuer::run", async move {
            shard_queuer.run().await;
        });

        (Arc::clone(&manager), return_value_rx)
    }

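    /// Returns whether the manager currently has a runner for the given shard ID.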
    pub async fn has(&self, shard_id: ShardId) -> bool {
        self.runners.lock().await.contains_key(&shard_id)
    }

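    /// Queues the configured range of shards (`shard_index..shard_index + shard_init`) to be
    /// started by the shard queuer.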
    #[instrument(skip(self))]
    pub fn initialize(&self) -> Result<()> {
        let shard_index = self.shard_index.load(Ordering::Relaxed);
        let shard_init = self.shard_init.load(Ordering::Relaxed);
        let shard_total = self.shard_total.load(Ordering::Relaxed);

        let shard_to = shard_index + shard_init;

        for shard_id in shard_index..shard_to {
            self.boot([ShardId(shard_id), ShardId(shard_total)]);
        }

        Ok(())
    }

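    /// Shuts down all shards and stores a new shard index, initialization count, and total,
    /// to be used by the next call to [`Self::initialize`].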
    #[instrument(skip(self))]
    pub async fn set_shards(&self, index: u32, init: u32, total: u32) {
        self.shutdown_all().await;

        self.shard_index.store(index, Ordering::Relaxed);
        self.shard_init.store(init, Ordering::Relaxed);
        self.shard_total.store(total, Ordering::Relaxed);
    }

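    /// Restarts the given shard: shuts it down with close code `4000` and queues it to be
    /// started again.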
    #[instrument(skip(self))]
    pub async fn restart(&self, shard_id: ShardId) {
        info!("Restarting shard {}", shard_id);
        self.shutdown(shard_id, 4000).await;

        let shard_total = self.shard_total.load(Ordering::Relaxed);

        self.boot([shard_id, ShardId(shard_total)]);
    }

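    /// Returns the IDs of all shards that currently have a runner.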
    #[instrument(skip(self))]
    pub async fn shards_instantiated(&self) -> Vec<ShardId> {
        self.runners.lock().await.keys().copied().collect()
    }

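    /// Shuts down the given shard with the provided close code, waiting up to five seconds for
    /// the runner to confirm the shutdown before removing it from the runner map.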
    #[instrument(skip(self))]
    pub async fn shutdown(&self, shard_id: ShardId, code: u16) {
        const TIMEOUT: Duration = Duration::from_secs(5);

        info!("Shutting down shard {}", shard_id);

        {
            let mut shard_shutdown = self.shard_shutdown.lock().await;

            drop(
                self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
            );
            match timeout(TIMEOUT, shard_shutdown.next()).await {
                Ok(Some(shutdown_shard_id)) => {
                    if shutdown_shard_id != shard_id {
                        warn!(
                            "Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
                            shard_id,
                        );
                    }
                },
                Ok(None) => (),
                Err(why) => {
                    warn!(
                        "Failed to cleanly shutdown shard {}, reached timeout: {:?}",
                        shard_id, why
                    );
                },
            }
        }

        self.runners.lock().await.remove(&shard_id);
    }

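    /// Shuts down every running shard, tells the shard queuer to stop, and sends `Ok(())` on
    /// the manager's return-value channel.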
    #[instrument(skip(self))]
    pub async fn shutdown_all(&self) {
        let keys = {
            let runners = self.runners.lock().await;

            if runners.is_empty() {
                return;
            }

            runners.keys().copied().collect::<Vec<_>>()
        };

        info!("Shutting down all shards");

        for shard_id in keys {
            self.shutdown(shard_id, 1000).await;
        }

        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));

        drop(self.return_value_tx.lock().await.unbounded_send(Ok(())));
    }

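    /// Tells the shard queuer to start the given shard, where `shard_info` is
    /// `[shard_id, shard_total]`.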
    #[instrument(skip(self))]
    fn boot(&self, shard_info: [ShardId; 2]) {
        info!("Telling shard queuer to start shard {}", shard_info[0]);

        let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);

        drop(self.shard_queuer.unbounded_send(msg));
    }

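    /// Returns the gateway intents this manager was configured with.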
    #[must_use]
    pub fn intents(&self) -> GatewayIntents {
        self.gateway_intents
    }

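    /// Sends a final result over the return-value channel handed out by [`Self::new`].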
    pub async fn return_with_value(&self, ret: Result<(), GatewayError>) {
        if let Err(e) = self.return_value_tx.lock().await.send(ret).await {
            tracing::warn!("failed to send return value: {}", e);
        }
    }

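    /// Notifies the manager that the given shard has finished shutting down, unblocking a
    /// pending [`Self::shutdown`] call.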
    pub fn shutdown_finished(&self, id: ShardId) {
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

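    /// Restarts the given shard and then signals that its shutdown has finished.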
    pub async fn restart_shard(&self, id: ShardId) {
        self.restart(id).await;
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

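    /// Updates the recorded latency and connection stage for the given shard, if it has a
    /// runner.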
    pub async fn update_shard_latency_and_stage(
        &self,
        id: ShardId,
        latency: Option<Duration>,
        stage: ConnectionStage,
    ) {
        if let Some(runner) = self.runners.lock().await.get_mut(&id) {
            runner.latency = latency;
            runner.stage = stage;
        }
    }
}

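/// Tells the shard queuer to shut down when the manager is dropped.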
impl Drop for ShardManager {
    fn drop(&mut self) {
        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
    }
}

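/// The options used to construct a [`ShardManager`] via [`ShardManager::new`].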
pub struct ShardManagerOptions {
    pub data: Arc<RwLock<TypeMap>>,
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    pub shard_index: u32,
    pub shard_init: u32,
    pub shard_total: u32,
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    pub ws_url: Arc<Mutex<String>>,
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    pub http: Arc<Http>,
    pub intents: GatewayIntents,
    pub presence: Option<PresenceData>,
}