serenity/gateway/bridge/
shard_manager.rs

use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use std::time::Duration;

use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use futures::{SinkExt, StreamExt};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{info, instrument, warn};
use typemap_rev::TypeMap;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;
/// A manager for handling the status of shards by starting them, restarting them, and stopping
/// them when required.
///
/// **Note**: The [`Client`] internally uses a shard manager. If you are using a Client, then you
/// do not need to make one of these.
///
/// # Examples
///
/// Initialize a shard manager with a framework responsible for shards 0 through 2, of 5 total
/// shards:
///
/// ```rust,no_run
/// # use std::error::Error;
/// #
/// # #[cfg(feature = "voice")]
/// # use serenity::model::id::UserId;
/// # #[cfg(feature = "cache")]
/// # use serenity::cache::Cache;
/// #
/// # #[cfg(feature = "framework")]
/// # async fn run() -> Result<(), Box<dyn Error>> {
/// #
/// use std::env;
/// use std::sync::{Arc, OnceLock};
///
/// use serenity::client::{EventHandler, RawEventHandler};
/// use serenity::framework::{Framework, StandardFramework};
/// use serenity::gateway::{ShardManager, ShardManagerOptions};
/// use serenity::http::Http;
/// use serenity::model::gateway::GatewayIntents;
/// use serenity::prelude::*;
/// use tokio::sync::{Mutex, RwLock};
///
/// struct Handler;
///
/// impl EventHandler for Handler {}
/// impl RawEventHandler for Handler {}
///
/// # let http: Arc<Http> = unimplemented!();
/// let ws_url = Arc::new(Mutex::new(http.get_gateway().await?.url));
/// let data = Arc::new(RwLock::new(TypeMap::new()));
/// let event_handler = Arc::new(Handler) as Arc<dyn EventHandler>;
/// let framework = Arc::new(StandardFramework::new()) as Arc<dyn Framework + 'static>;
///
/// ShardManager::new(ShardManagerOptions {
///     data,
///     event_handlers: vec![event_handler],
///     raw_event_handlers: vec![],
///     framework: Arc::new(OnceLock::from(framework)),
///     // the shard index to start initiating from
///     shard_index: 0,
///     // the number of shards to initiate (this initiates 0, 1, and 2)
///     shard_init: 3,
///     // the total number of shards in use
///     shard_total: 5,
///     # #[cfg(feature = "voice")]
///     # voice_manager: None,
///     ws_url,
///     # #[cfg(feature = "cache")]
///     # cache: unimplemented!(),
///     # http,
///     intents: GatewayIntents::non_privileged(),
///     presence: None,
/// });
/// # Ok(())
/// # }
/// ```
///
/// [`Client`]: crate::Client
#[derive(Debug)]
pub struct ShardManager {
    return_value_tx: Mutex<Sender<Result<(), GatewayError>>>,
    /// The shard runners currently managed.
    ///
    /// **Note**: Mutating this yourself is strongly discouraged unless absolutely necessary.
    /// Prefer the methods provided on this struct where possible.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// The index of the first shard to initialize, 0-indexed.
    // Atomics are used here to allow for mutation without requiring a mutable reference to self.
    shard_index: AtomicU32,
    /// The number of shards to initialize.
    shard_init: AtomicU32,
    /// The total number of shards in use, 1-indexed.
    shard_total: AtomicU32,
    shard_queuer: Sender<ShardQueuerMessage>,
    // We can safely use a Mutex for this field, as it is only ever used in a single place and
    // only ever to receive a single message.
    shard_shutdown: Mutex<Receiver<ShardId>>,
    shard_shutdown_send: Sender<ShardId>,
    gateway_intents: GatewayIntents,
}

impl ShardManager {
    /// Creates a new shard manager, returning both the manager and a receiver over which the
    /// manager's final result is delivered (for example, once all shards have been shut down).
    #[must_use]
    pub fn new(opt: ShardManagerOptions) -> (Arc<Self>, Receiver<Result<(), GatewayError>>) {
        let (return_value_tx, return_value_rx) = mpsc::unbounded();
        let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded();

        let runners = Arc::new(Mutex::new(HashMap::new()));
        let (shutdown_send, shutdown_recv) = mpsc::unbounded();

        let manager = Arc::new(Self {
            return_value_tx: Mutex::new(return_value_tx),
            shard_index: AtomicU32::new(opt.shard_index),
            shard_init: AtomicU32::new(opt.shard_init),
            shard_queuer: shard_queue_tx,
            shard_total: AtomicU32::new(opt.shard_total),
            shard_shutdown: Mutex::new(shutdown_recv),
            shard_shutdown_send: shutdown_send,
            runners: Arc::clone(&runners),
            gateway_intents: opt.intents,
        });

        let mut shard_queuer = ShardQueuer {
            data: opt.data,
            event_handlers: opt.event_handlers,
            raw_event_handlers: opt.raw_event_handlers,
            #[cfg(feature = "framework")]
            framework: opt.framework,
            last_start: None,
            manager: Arc::clone(&manager),
            queue: VecDeque::new(),
            runners,
            rx: shard_queue_rx,
            #[cfg(feature = "voice")]
            voice_manager: opt.voice_manager,
            ws_url: opt.ws_url,
            #[cfg(feature = "cache")]
            cache: opt.cache,
            http: opt.http,
            intents: opt.intents,
            presence: opt.presence,
        };

        spawn_named("shard_queuer::run", async move {
            shard_queuer.run().await;
        });

        (Arc::clone(&manager), return_value_rx)
    }

    /// Returns whether the shard manager contains an active instance of a shard runner
    /// responsible for the given ID.
    ///
    /// If a shard has been queued but has not yet been initiated, then this will return `false`.
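    ///
    /// # Examples
    ///
    /// A minimal sketch of checking whether a shard currently has a runner, assuming a running
    /// `Client` (the shard ID here is illustrative):
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// if client.shard_manager.has(ShardId(0)).await {
    ///     println!("shard 0 has an active runner");
    /// }
    /// # }
    /// ```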
    pub async fn has(&self, shard_id: ShardId) -> bool {
        self.runners.lock().await.contains_key(&shard_id)
    }

    /// Initializes all shards that the manager is responsible for.
    ///
    /// This will communicate shard boots with the [`ShardQueuer`] so that they are properly
    /// queued.
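    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `manager` was built as in the [`ShardManager`] example above:
    ///
    /// ```rust,no_run
    /// # async fn run(
    /// #     manager: std::sync::Arc<serenity::gateway::ShardManager>,
    /// # ) -> Result<(), Box<dyn std::error::Error>> {
    /// // Queue boots for every shard this manager is responsible for.
    /// manager.initialize()?;
    /// # Ok(())
    /// # }
    /// ```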
    #[instrument(skip(self))]
    pub fn initialize(&self) -> Result<()> {
        let shard_index = self.shard_index.load(Ordering::Relaxed);
        let shard_init = self.shard_init.load(Ordering::Relaxed);
        let shard_total = self.shard_total.load(Ordering::Relaxed);

        let shard_to = shard_index + shard_init;

        for shard_id in shard_index..shard_to {
            self.boot([ShardId(shard_id), ShardId(shard_total)]);
        }

        Ok(())
    }

    /// Sets the new sharding information for the manager.
    ///
    /// This will shut down all existing shards.
    ///
    /// This will _not_ instantiate the new shards.
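    ///
    /// # Examples
    ///
    /// A minimal sketch of re-sharding to a larger shard total, assuming a running `Client` (the
    /// shard values here are illustrative):
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) -> Result<(), Box<dyn std::error::Error>> {
    /// // Shut down the current shards and record the new sharding scheme.
    /// client.shard_manager.set_shards(0, 10, 10).await;
    /// // The new shards are not started automatically; queue them explicitly.
    /// client.shard_manager.initialize()?;
    /// # Ok(())
    /// # }
    /// ```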
    #[instrument(skip(self))]
    pub async fn set_shards(&self, index: u32, init: u32, total: u32) {
        self.shutdown_all().await;

        self.shard_index.store(index, Ordering::Relaxed);
        self.shard_init.store(init, Ordering::Relaxed);
        self.shard_total.store(total, Ordering::Relaxed);
    }

    /// Restarts a shard runner.
    ///
    /// This sends a shutdown signal to a shard's associated [`ShardRunner`], and then queues an
    /// initialization of a shard runner for the same shard via the [`ShardQueuer`].
    ///
    /// # Examples
    ///
    /// Restarting a shard by ID:
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// // restart shard ID 7
    /// client.shard_manager.restart(ShardId(7)).await;
    /// # }
    /// ```
    ///
    /// [`ShardRunner`]: super::ShardRunner
    #[instrument(skip(self))]
    pub async fn restart(&self, shard_id: ShardId) {
        info!("Restarting shard {}", shard_id);
        self.shutdown(shard_id, 4000).await;

        let shard_total = self.shard_total.load(Ordering::Relaxed);

        self.boot([shard_id, ShardId(shard_total)]);
    }

    /// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a
    /// valid [`ShardRunner`].
    ///
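    /// # Examples
    ///
    /// A minimal sketch of listing the instantiated shards, assuming a running `Client`:
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// for shard_id in client.shard_manager.shards_instantiated().await {
    ///     println!("shard {} has a runner", shard_id);
    /// }
    /// # }
    /// ```
    ///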
    /// [`ShardRunner`]: super::ShardRunner
    #[instrument(skip(self))]
    pub async fn shards_instantiated(&self) -> Vec<ShardId> {
        self.runners.lock().await.keys().copied().collect()
    }

    /// Attempts to shut down the shard runner by ID.
    ///
    /// This sends a shutdown signal to the shard's runner and waits up to five seconds for
    /// confirmation that it has shut down, logging a warning if that confirmation does not arrive
    /// in time. The runner's entry is then removed from [`Self::runners`].
    ///
    /// **Note**: If the receiving end of an mpsc channel - owned by the shard runner - no longer
    /// exists, then the shard runner will not know it should shut down. This _should never
    /// happen_; the runner may simply already be stopped.
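    ///
    /// # Examples
    ///
    /// A minimal sketch of shutting down a single shard with a normal close code, assuming a
    /// running `Client` (mirrors the [`Self::restart`] example; the shard ID is illustrative):
    ///
    /// ```rust,no_run
    /// use serenity::model::id::ShardId;
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// // Shut down shard 7 with close code 1000 (normal closure).
    /// client.shard_manager.shutdown(ShardId(7), 1000).await;
    /// # }
    /// ```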
    #[instrument(skip(self))]
    pub async fn shutdown(&self, shard_id: ShardId, code: u16) {
        const TIMEOUT: Duration = Duration::from_secs(5);

        info!("Shutting down shard {}", shard_id);

        {
            let mut shard_shutdown = self.shard_shutdown.lock().await;

            drop(
                self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
            );
            match timeout(TIMEOUT, shard_shutdown.next()).await {
                Ok(Some(shutdown_shard_id)) => {
                    if shutdown_shard_id != shard_id {
                        warn!(
                            "Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
                            shard_id,
                        );
                    }
                },
                Ok(None) => (),
                Err(why) => {
                    warn!(
                        "Failed to cleanly shutdown shard {}, reached timeout: {:?}",
                        shard_id, why
                    );
                },
            }
            // shard_shutdown is dropped here and releases the lock
            // in theory we should never have two calls to shutdown()
            // at the same time but this is a safety measure just in case:tm:
        }

        self.runners.lock().await.remove(&shard_id);
    }

    /// Sends a shutdown message for all shards that the manager is responsible for that are still
    /// known to be running.
    ///
    /// If you only need to shut down a select number of shards, prefer looping over the
    /// [`Self::shutdown`] method.
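    ///
    /// # Examples
    ///
    /// A minimal sketch of shutting down every running shard, assuming a running `Client`:
    ///
    /// ```rust,no_run
    /// use serenity::prelude::*;
    ///
    /// # async fn run(client: Client) {
    /// client.shard_manager.shutdown_all().await;
    /// # }
    /// ```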
    #[instrument(skip(self))]
    pub async fn shutdown_all(&self) {
        let keys = {
            let runners = self.runners.lock().await;

            if runners.is_empty() {
                return;
            }

            runners.keys().copied().collect::<Vec<_>>()
        };

        info!("Shutting down all shards");

        for shard_id in keys {
            self.shutdown(shard_id, 1000).await;
        }

        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));

        // this message is received by Client::start_connection, which lets the main thread know
        // so that it can finally return from Client::start
        drop(self.return_value_tx.lock().await.unbounded_send(Ok(())));
    }

    #[instrument(skip(self))]
    fn boot(&self, shard_info: [ShardId; 2]) {
        info!("Telling shard queuer to start shard {}", shard_info[0]);

        let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);

        drop(self.shard_queuer.unbounded_send(msg));
    }

    /// Returns the gateway intents used for this connection.
    #[must_use]
    pub fn intents(&self) -> GatewayIntents {
        self.gateway_intents
    }

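    /// Sends `ret` over the manager's return-value channel, making it available on the receiver
    /// returned by [`Self::new`].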
    pub async fn return_with_value(&self, ret: Result<(), GatewayError>) {
        if let Err(e) = self.return_value_tx.lock().await.send(ret).await {
            tracing::warn!("failed to send return value: {}", e);
        }
    }

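    /// Notifies the manager that the shard with the given ID has finished shutting down,
    /// unblocking a pending [`Self::shutdown`] call for that shard.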
    pub fn shutdown_finished(&self, id: ShardId) {
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

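    /// Restarts the shard with the given ID via [`Self::restart`], then sends a shutdown-finished
    /// notification for that shard on the shutdown channel.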
    pub async fn restart_shard(&self, id: ShardId) {
        self.restart(id).await;
        if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
            tracing::warn!("failed to notify about finished shutdown: {}", e);
        }
    }

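    /// Updates the latency and connection stage recorded in [`Self::runners`] for the given
    /// shard, if a runner for it exists.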
    pub async fn update_shard_latency_and_stage(
        &self,
        id: ShardId,
        latency: Option<Duration>,
        stage: ConnectionStage,
    ) {
        if let Some(runner) = self.runners.lock().await.get_mut(&id) {
            runner.latency = latency;
            runner.stage = stage;
        }
    }
}

impl Drop for ShardManager {
    /// A custom drop implementation to clean up after the manager.
    ///
    /// This shuts down all active [`ShardRunner`]s and attempts to tell the [`ShardQueuer`] to
    /// shut down.
    ///
    /// [`ShardRunner`]: super::ShardRunner
    fn drop(&mut self) {
        drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
    }
}

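/// Options for configuring a [`ShardManager`], passed to [`ShardManager::new`].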
pub struct ShardManagerOptions {
    pub data: Arc<RwLock<TypeMap>>,
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    pub shard_index: u32,
    pub shard_init: u32,
    pub shard_total: u32,
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    pub ws_url: Arc<Mutex<String>>,
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    pub http: Arc<Http>,
    pub intents: GatewayIntents,
    pub presence: Option<PresenceData>,
}