1use std::collections::{HashMap, VecDeque};
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::Arc;
4#[cfg(feature = "framework")]
5use std::sync::OnceLock;
6use std::time::Duration;
7
8use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
9use futures::{SinkExt, StreamExt};
10use tokio::sync::{Mutex, RwLock};
11use tokio::time::timeout;
12use tracing::{info, instrument, warn};
13use typemap_rev::TypeMap;
14
15#[cfg(feature = "voice")]
16use super::VoiceGatewayManager;
17use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
18#[cfg(feature = "cache")]
19use crate::cache::Cache;
20use crate::client::{EventHandler, RawEventHandler};
21#[cfg(feature = "framework")]
22use crate::framework::Framework;
23use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
24use crate::http::Http;
25use crate::internal::prelude::*;
26use crate::internal::tokio::spawn_named;
27use crate::model::gateway::GatewayIntents;
28
/// Coordinates the lifecycle of all shard runners: starting them (via the
/// shard queuer task), tracking their state, and shutting them down.
#[derive(Debug)]
pub struct ShardManager {
    /// Sender used to deliver the manager's final result (or gateway error)
    /// to the receiver handed out by [`Self::new`].
    return_value_tx: Mutex<Sender<Result<(), GatewayError>>>,
    /// Live information about each running shard, keyed by shard ID.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// ID of the first shard this manager is responsible for starting.
    shard_index: AtomicU32,
    /// Number of shards this manager starts, beginning at `shard_index`.
    shard_init: AtomicU32,
    /// Total number of shards across the whole bot (may exceed what this
    /// manager starts, e.g. in multi-process sharding).
    shard_total: AtomicU32,
    /// Channel to the background shard queuer task (start/shutdown commands).
    shard_queuer: Sender<ShardQueuerMessage>,
    /// Receives the ID of each shard once its shutdown has completed;
    /// awaited in [`Self::shutdown`] to confirm a clean stop.
    shard_shutdown: Mutex<Receiver<ShardId>>,
    /// Sender half of `shard_shutdown`; signalled by shard runners via
    /// [`Self::shutdown_finished`].
    shard_shutdown_send: Sender<ShardId>,
    /// Gateway intents every shard connects with.
    gateway_intents: GatewayIntents,
}
121impl ShardManager {
122 #[must_use]
125 pub fn new(opt: ShardManagerOptions) -> (Arc<Self>, Receiver<Result<(), GatewayError>>) {
126 let (return_value_tx, return_value_rx) = mpsc::unbounded();
127 let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded();
128
129 let runners = Arc::new(Mutex::new(HashMap::new()));
130 let (shutdown_send, shutdown_recv) = mpsc::unbounded();
131
132 let manager = Arc::new(Self {
133 return_value_tx: Mutex::new(return_value_tx),
134 shard_index: AtomicU32::new(opt.shard_index),
135 shard_init: AtomicU32::new(opt.shard_init),
136 shard_queuer: shard_queue_tx,
137 shard_total: AtomicU32::new(opt.shard_total),
138 shard_shutdown: Mutex::new(shutdown_recv),
139 shard_shutdown_send: shutdown_send,
140 runners: Arc::clone(&runners),
141 gateway_intents: opt.intents,
142 });
143
144 let mut shard_queuer = ShardQueuer {
145 data: opt.data,
146 event_handlers: opt.event_handlers,
147 raw_event_handlers: opt.raw_event_handlers,
148 #[cfg(feature = "framework")]
149 framework: opt.framework,
150 last_start: None,
151 manager: Arc::clone(&manager),
152 queue: VecDeque::new(),
153 runners,
154 rx: shard_queue_rx,
155 #[cfg(feature = "voice")]
156 voice_manager: opt.voice_manager,
157 ws_url: opt.ws_url,
158 #[cfg(feature = "cache")]
159 cache: opt.cache,
160 http: opt.http,
161 intents: opt.intents,
162 presence: opt.presence,
163 };
164
165 spawn_named("shard_queuer::run", async move {
166 shard_queuer.run().await;
167 });
168
169 (Arc::clone(&manager), return_value_rx)
170 }
171
172 pub async fn has(&self, shard_id: ShardId) -> bool {
177 self.runners.lock().await.contains_key(&shard_id)
178 }
179
180 #[instrument(skip(self))]
185 #[allow(clippy::missing_errors_doc)] pub fn initialize(&self) -> Result<()> {
187 let shard_index = self.shard_index.load(Ordering::Relaxed);
188 let shard_init = self.shard_init.load(Ordering::Relaxed);
189 let shard_total = self.shard_total.load(Ordering::Relaxed);
190
191 let shard_to = shard_index + shard_init;
192
193 for shard_id in shard_index..shard_to {
194 self.boot([ShardId(shard_id), ShardId(shard_total)]);
195 }
196
197 Ok(())
198 }
199
200 #[instrument(skip(self))]
206 pub async fn set_shards(&self, index: u32, init: u32, total: u32) {
207 self.shutdown_all().await;
208
209 self.shard_index.store(index, Ordering::Relaxed);
210 self.shard_init.store(init, Ordering::Relaxed);
211 self.shard_total.store(total, Ordering::Relaxed);
212 }
213
214 #[instrument(skip(self))]
235 pub async fn restart(&self, shard_id: ShardId) {
236 info!("Restarting shard {}", shard_id);
237 self.shutdown(shard_id, 4000).await;
238
239 let shard_total = self.shard_total.load(Ordering::Relaxed);
240
241 self.boot([shard_id, ShardId(shard_total)]);
242 }
243
244 #[instrument(skip(self))]
249 pub async fn shards_instantiated(&self) -> Vec<ShardId> {
250 self.runners.lock().await.keys().copied().collect()
251 }
252
253 #[instrument(skip(self))]
262 pub async fn shutdown(&self, shard_id: ShardId, code: u16) {
263 const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);
264
265 info!("Shutting down shard {}", shard_id);
266
267 {
268 let mut shard_shutdown = self.shard_shutdown.lock().await;
269
270 drop(
271 self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
272 );
273 match timeout(TIMEOUT, shard_shutdown.next()).await {
274 Ok(Some(shutdown_shard_id)) => {
275 if shutdown_shard_id != shard_id {
276 warn!(
277 "Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
278 shard_id,
279 );
280 }
281 },
282 Ok(None) => (),
283 Err(why) => {
284 warn!(
285 "Failed to cleanly shutdown shard {}, reached timeout: {:?}",
286 shard_id, why
287 );
288 },
289 }
290 }
294
295 self.runners.lock().await.remove(&shard_id);
296 }
297
298 #[instrument(skip(self))]
304 pub async fn shutdown_all(&self) {
305 let keys = {
306 let runners = self.runners.lock().await;
307
308 if runners.is_empty() {
309 return;
310 }
311
312 runners.keys().copied().collect::<Vec<_>>()
313 };
314
315 info!("Shutting down all shards");
316
317 for shard_id in keys {
318 self.shutdown(shard_id, 1000).await;
319 }
320
321 drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
322
323 drop(self.return_value_tx.lock().await.unbounded_send(Ok(())));
326 }
327
328 #[instrument(skip(self))]
329 fn boot(&self, shard_info: [ShardId; 2]) {
330 info!("Telling shard queuer to start shard {}", shard_info[0]);
331
332 let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);
333
334 drop(self.shard_queuer.unbounded_send(msg));
335 }
336
337 #[must_use]
339 pub fn intents(&self) -> GatewayIntents {
340 self.gateway_intents
341 }
342
343 pub async fn return_with_value(&self, ret: Result<(), GatewayError>) {
344 if let Err(e) = self.return_value_tx.lock().await.send(ret).await {
345 tracing::warn!("failed to send return value: {}", e);
346 }
347 }
348
349 pub fn shutdown_finished(&self, id: ShardId) {
350 if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
351 tracing::warn!("failed to notify about finished shutdown: {}", e);
352 }
353 }
354
355 pub async fn restart_shard(&self, id: ShardId) {
356 self.restart(id).await;
357 if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
358 tracing::warn!("failed to notify about finished shutdown: {}", e);
359 }
360 }
361
362 pub async fn update_shard_latency_and_stage(
363 &self,
364 id: ShardId,
365 latency: Option<Duration>,
366 stage: ConnectionStage,
367 ) {
368 if let Some(runner) = self.runners.lock().await.get_mut(&id) {
369 runner.latency = latency;
370 runner.stage = stage;
371 }
372 }
373}
374
375impl Drop for ShardManager {
376 fn drop(&mut self) {
383 drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
384 }
385}
386
/// The set of values needed to construct a [`ShardManager`]; handed to
/// [`ShardManager::new`] and largely forwarded into the spawned shard queuer.
pub struct ShardManagerOptions {
    /// Shared, user-supplied data map made available to shards.
    pub data: Arc<RwLock<TypeMap>>,
    /// Event handlers invoked for parsed gateway events.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Event handlers invoked with raw, unparsed gateway events.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Command framework, set lazily once available.
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    /// ID of the first shard to start.
    pub shard_index: u32,
    /// Number of shards to start, beginning at `shard_index`.
    pub shard_init: u32,
    /// Total shard count across the whole bot.
    pub shard_total: u32,
    /// Optional voice gateway manager shared with shards.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    /// Gateway websocket URL shards connect to.
    pub ws_url: Arc<Mutex<String>>,
    /// Shared cache instance.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client shards use for REST calls.
    pub http: Arc<Http>,
    /// Gateway intents every shard identifies with.
    pub intents: GatewayIntents,
    /// Initial presence to send on identify, if any.
    pub presence: Option<PresenceData>,
}