serenity/gateway/bridge/shard_manager.rs

use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use std::time::Duration;

use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use futures::{SinkExt, StreamExt};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{info, instrument, warn};
use typemap_rev::TypeMap;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;

/// A manager for handling the status of shards by starting them, restarting them, and stopping
/// them when required.
///
/// **Note**: The [`Client`] internally uses a shard manager. If you are using a Client, then you
/// do not need to make one of these.
///
/// # Examples
///
/// Initialize a shard manager with a framework responsible for shards 0 through 2, of 5 total
/// shards:
///
/// ```rust,no_run
/// # use std::error::Error;
/// #
/// # #[cfg(feature = "voice")]
/// # use serenity::model::id::UserId;
/// # #[cfg(feature = "cache")]
/// # use serenity::cache::Cache;
/// #
/// # #[cfg(feature = "framework")]
/// # async fn run() -> Result<(), Box<dyn Error>> {
/// #
/// use std::env;
/// use std::sync::{Arc, OnceLock};
///
/// use serenity::client::{EventHandler, RawEventHandler};
/// use serenity::framework::{Framework, StandardFramework};
/// use serenity::gateway::{ShardManager, ShardManagerOptions};
/// use serenity::http::Http;
/// use serenity::model::gateway::GatewayIntents;
/// use serenity::prelude::*;
/// use tokio::sync::{Mutex, RwLock};
///
/// struct Handler;
///
/// impl EventHandler for Handler {}
/// impl RawEventHandler for Handler {}
///
/// # let http: Arc<Http> = unimplemented!();
/// let ws_url = Arc::new(Mutex::new(http.get_gateway().await?.url));
/// let data = Arc::new(RwLock::new(TypeMap::new()));
/// let event_handler = Arc::new(Handler) as Arc<dyn EventHandler>;
/// let framework = Arc::new(StandardFramework::new()) as Arc<dyn Framework + 'static>;
///
/// ShardManager::new(ShardManagerOptions {
///     data,
///     event_handlers: vec![event_handler],
///     raw_event_handlers: vec![],
///     framework: Arc::new(OnceLock::from(framework)),
///     // the shard index to start initiating from
///     shard_index: 0,
///     // the number of shards to initiate (this initiates 0, 1, and 2)
///     shard_init: 3,
///     // the total number of shards in use
///     shard_total: 5,
///     # #[cfg(feature = "voice")]
///     # voice_manager: None,
///     ws_url,
///     # #[cfg(feature = "cache")]
///     # cache: unimplemented!(),
///     # http,
///     intents: GatewayIntents::non_privileged(),
///     presence: None,
/// });
/// # Ok(())
/// # }
/// ```
///
/// [`Client`]: crate::Client
#[derive(Debug)]
pub struct ShardManager {
    return_value_tx: Mutex<Sender<Result<(), GatewayError>>>,
    /// The shard runners currently managed.
    ///
    /// **Note**: Mutating this yourself is discouraged unless you need to. Prefer the methods
    /// provided on this struct where possible.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// The index of the first shard to initialize, 0-indexed.
    // Atomics are used here to allow for mutation without requiring a mutable reference to self.
    shard_index: AtomicU32,
    /// The number of shards to initialize.
    shard_init: AtomicU32,
    /// The total shards in use, 1-indexed.
    shard_total: AtomicU32,
    shard_queuer: Sender<ShardQueuerMessage>,
    // We can safely use a Mutex for this field, as it is only ever used in one single place
    // and is only ever used to receive a single message.
    shard_shutdown: Mutex<Receiver<ShardId>>,
    shard_shutdown_send: Sender<ShardId>,
    gateway_intents: GatewayIntents,
}

impl ShardManager {
    /// Creates a new shard manager, returning both the manager and a receiver that yields the
    /// final gateway result once all shards have been shut down or a fatal error occurs.
    #[must_use]
    pub fn new(opt: ShardManagerOptions) -> (Arc<Self>, Receiver<Result<(), GatewayError>>) {
        let (return_value_tx, return_value_rx) = mpsc::unbounded();
        let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded();

        let runners = Arc::new(Mutex::new(HashMap::new()));
        let (shutdown_send, shutdown_recv) = mpsc::unbounded();

        let manager = Arc::new(Self {
            return_value_tx: Mutex::new(return_value_tx),
            shard_index: AtomicU32::new(opt.shard_index),
            shard_init: AtomicU32::new(opt.shard_init),
            shard_queuer: shard_queue_tx,
            shard_total: AtomicU32::new(opt.shard_total),
            shard_shutdown: Mutex::new(shutdown_recv),
            shard_shutdown_send: shutdown_send,
            runners: Arc::clone(&runners),
            gateway_intents: opt.intents,
        });

        let mut shard_queuer = ShardQueuer {
            data: opt.data,
            event_handlers: opt.event_handlers,
            raw_event_handlers: opt.raw_event_handlers,
            #[cfg(feature = "framework")]
            framework: opt.framework,
            last_start: None,
            manager: Arc::clone(&manager),
            queue: VecDeque::new(),
            runners,
            rx: shard_queue_rx,
            #[cfg(feature = "voice")]
            voice_manager: opt.voice_manager,
            ws_url: opt.ws_url,
            #[cfg(feature = "cache")]
            cache: opt.cache,
            http: opt.http,
            intents: opt.intents,
            presence: opt.presence,
        };

        spawn_named("shard_queuer::run", async move {
            shard_queuer.run().await;
        });

        (Arc::clone(&manager), return_value_rx)
    }

    /// Returns whether the shard manager contains an active instance of a shard runner
    /// responsible for the given ID.
    ///
    /// If a shard has been queued but has not yet been initiated, then this will return `false`.
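    ///
    /// # Examples
    ///
    /// A sketch of checking whether shard 0 currently has a runner, assuming an existing
    /// `client`:
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// // shard 0 is only an example ID
    /// if client.shard_manager.has(ShardId(0)).await {
    ///     println!("shard 0 is running");
    /// }
    /// # }
    /// ```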
    pub async fn has(&self, shard_id: ShardId) -> bool {
        self.runners.lock().await.contains_key(&shard_id)
    }

    /// Initializes all shards that the manager is responsible for.
    ///
    /// This will communicate shard boots with the [`ShardQueuer`] so that they are properly
    /// queued.
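    ///
    /// # Examples
    ///
    /// A sketch of queueing all shards from an already constructed manager (`manager` here is a
    /// hypothetical, pre-built handle):
    ///
    /// ```rust,no_run
    /// # use std::sync::Arc;
    /// # use serenity::gateway::ShardManager;
    /// # fn run(manager: Arc<ShardManager>) -> Result<(), Box<dyn std::error::Error>> {
    /// // queue a boot for every shard this manager is responsible for
    /// manager.initialize()?;
    /// # Ok(())
    /// # }
    /// ```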
    #[instrument(skip(self))]
    #[allow(clippy::missing_errors_doc)] // Doesn't actually error, fixed on next.
    pub fn initialize(&self) -> Result<()> {
        let shard_index = self.shard_index.load(Ordering::Relaxed);
        let shard_init = self.shard_init.load(Ordering::Relaxed);
        let shard_total = self.shard_total.load(Ordering::Relaxed);

        let shard_to = shard_index + shard_init;

        for shard_id in shard_index..shard_to {
            self.boot([ShardId(shard_id), ShardId(shard_total)]);
        }

        Ok(())
    }

    /// Sets the new sharding information for the manager.
    ///
    /// This will shut down all existing shards.
    ///
    /// This will _not_ instantiate the new shards.
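    ///
    /// # Examples
    ///
    /// A sketch of switching an existing `client` to a 10-shard configuration (the numbers are
    /// only illustrative) and then queueing the new shards:
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) -> Result<(), Box<dyn std::error::Error>> {
    /// // shut down the current shards and record the new sharding scheme
    /// client.shard_manager.set_shards(0, 10, 10).await;
    /// // set_shards does not boot anything, so queue the new shards explicitly
    /// client.shard_manager.initialize()?;
    /// # Ok(())
    /// # }
    /// ```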
    #[instrument(skip(self))]
    pub async fn set_shards(&self, index: u32, init: u32, total: u32) {
        self.shutdown_all().await;

        self.shard_index.store(index, Ordering::Relaxed);
        self.shard_init.store(init, Ordering::Relaxed);
        self.shard_total.store(total, Ordering::Relaxed);
    }

    /// Restarts a shard runner.
    ///
    /// This sends a shutdown signal to a shard's associated [`ShardRunner`], and then queues an
    /// initialization of a shard runner for the same shard via the [`ShardQueuer`].
    ///
    /// # Examples
    ///
    /// Restarting a shard by ID:
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// // restart shard ID 7
    /// client.shard_manager.restart(ShardId(7)).await;
    /// # }
    /// ```
    ///
    /// [`ShardRunner`]: super::ShardRunner
    #[instrument(skip(self))]
    pub async fn restart(&self, shard_id: ShardId) {
        info!("Restarting shard {}", shard_id);
        self.shutdown(shard_id, 4000).await;

        let shard_total = self.shard_total.load(Ordering::Relaxed);

        self.boot([shard_id, ShardId(shard_total)]);
    }

    /// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a
    /// valid [`ShardRunner`].
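    ///
    /// # Examples
    ///
    /// A sketch of listing the currently instantiated shards, assuming an existing `client`:
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// for shard_id in client.shard_manager.shards_instantiated().await {
    ///     println!("shard {} has a runner", shard_id);
    /// }
    /// # }
    /// ```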
    ///
    /// [`ShardRunner`]: super::ShardRunner
    #[instrument(skip(self))]
    pub async fn shards_instantiated(&self) -> Vec<ShardId> {
        self.runners.lock().await.keys().copied().collect()
    }

    /// Attempts to shut down the shard runner by ID.
    ///
    /// This sends a shutdown signal, with the given close code, to the shard's associated shard
    /// runner and waits up to five seconds for confirmation that the runner has shut down.
    ///
    /// **Note**: If the receiving end of an mpsc channel, owned by the shard runner, no longer
    /// exists, then the shard runner will not know it should shut down. This _should never
    /// happen_; it may already be stopped.
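    ///
    /// # Examples
    ///
    /// A sketch of shutting down shard 1 with a normal close code, assuming an existing `client`:
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// // shard 1 and close code 1000 are only example values
    /// client.shard_manager.shutdown(ShardId(1), 1000).await;
    /// # }
    /// ```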
    #[instrument(skip(self))]
    pub async fn shutdown(&self, shard_id: ShardId, code: u16) {
        const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);

        info!("Shutting down shard {}", shard_id);

        {
            let mut shard_shutdown = self.shard_shutdown.lock().await;

            drop(
                self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
            );
            match timeout(TIMEOUT, shard_shutdown.next()).await {
                Ok(Some(shutdown_shard_id)) => {
                    if shutdown_shard_id != shard_id {
                        warn!(
                            "Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
                            shard_id,
                        );
                    }
                },
                Ok(None) => (),
                Err(why) => {
                    warn!(
                        "Failed to cleanly shutdown shard {}, reached timeout: {:?}",
                        shard_id, why
                    );
                },
            }
            // shard_shutdown is dropped here and releases the lock.
            // In theory we should never have two calls to shutdown() at the same time, but this
            // is a safety measure just in case:tm:
        }

        self.runners.lock().await.remove(&shard_id);
    }

    /// Sends a shutdown message for all shards that the manager is responsible for that are still
    /// known to be running.
    ///
    /// If you only need to shut down a select number of shards, prefer looping over the
    /// [`Self::shutdown`] method.
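    ///
    /// # Examples
    ///
    /// A sketch of shutting down every shard, e.g. before exiting the process, assuming an
    /// existing `client`:
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// client.shard_manager.shutdown_all().await;
    /// # }
    /// ```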
    #[instrument(skip(self))]
    pub async fn shutdown_all(&self) {
        let keys = {
            let runners = self.runners.lock().await;

            if runners.is_empty() {
                return;
            }

            runners.keys().copied().collect::<Vec<_>>()
        };

        info!("Shutting down all shards");

        for shard_id in keys {
            self.shutdown(shard_id, 1000).await;
        }

        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));

        // this message is received by Client::start_connection, which lets the main thread know
        // and finally return from Client::start
        drop(self.return_value_tx.lock().await.unbounded_send(Ok(())));
    }

    #[instrument(skip(self))]
    fn boot(&self, shard_info: [ShardId; 2]) {
        info!("Telling shard queuer to start shard {}", shard_info[0]);

        let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);

        drop(self.shard_queuer.unbounded_send(msg));
    }

    /// Returns the gateway intents used for this gateway connection.
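    ///
    /// # Examples
    ///
    /// A sketch of checking for a specific intent, assuming an existing `client`:
    ///
    /// ```rust,no_run
    /// use serenity::model::gateway::GatewayIntents;
    /// use serenity::prelude::*;
    ///
    /// # fn run(client: Client) {
    /// // MESSAGE_CONTENT is only an example flag
    /// if client.shard_manager.intents().contains(GatewayIntents::MESSAGE_CONTENT) {
    ///     println!("message content intent is enabled");
    /// }
    /// # }
    /// ```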
    #[must_use]
    pub fn intents(&self) -> GatewayIntents {
        self.gateway_intents
    }

    /// Sends the given value through the return-value channel, to be yielded by the receiver
    /// handed out by [`Self::new`].
    pub async fn return_with_value(&self, ret: Result<(), GatewayError>) {
        if let Err(e) = self.return_value_tx.lock().await.send(ret).await {
            tracing::warn!("failed to send return value: {}", e);
        }
    }

    /// Notifies the manager that the shard runner for the given ID has finished shutting down.
    pub fn shutdown_finished(&self, id: ShardId) {
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

    /// Restarts the shard runner for the given ID and notifies the manager that the old runner
    /// has shut down.
    pub async fn restart_shard(&self, id: ShardId) {
        self.restart(id).await;
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

    /// Updates the latency and connection stage recorded for the given shard's runner, if one is
    /// present.
    pub async fn update_shard_latency_and_stage(
        &self,
        id: ShardId,
        latency: Option<Duration>,
        stage: ConnectionStage,
    ) {
        if let Some(runner) = self.runners.lock().await.get_mut(&id) {
            runner.latency = latency;
            runner.stage = stage;
        }
    }
}

impl Drop for ShardManager {
    /// A custom drop implementation to clean up after the manager.
    ///
    /// This shuts down all active [`ShardRunner`]s and attempts to tell the [`ShardQueuer`] to
    /// shut down.
    ///
    /// [`ShardRunner`]: super::ShardRunner
    fn drop(&mut self) {
        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
    }
}

/// Options for [`ShardManager::new`].
pub struct ShardManagerOptions {
    pub data: Arc<RwLock<TypeMap>>,
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    pub shard_index: u32,
    pub shard_init: u32,
    pub shard_total: u32,
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    pub ws_url: Arc<Mutex<String>>,
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    pub http: Arc<Http>,
    pub intents: GatewayIntents,
    pub presence: Option<PresenceData>,
}