serenity/gateway/bridge/shard_queuer.rs

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;

use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::StreamExt;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, timeout, Duration, Instant};
use tracing::{debug, info, instrument, warn};
use typemap_rev::TypeMap;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{
    ShardId,
    ShardManager,
    ShardMessenger,
    ShardQueuerMessage,
    ShardRunner,
    ShardRunnerInfo,
    ShardRunnerOptions,
};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::{GatewayIntents, ShardInfo};

const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;

/// The shard queuer is a simple loop that runs indefinitely to manage the startup of shards.
///
/// A shard queuer instance _should_ be run in its own thread, due to the blocking nature of the
/// loop itself as well as the 5-second sleep between shard starts.
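///
/// # Example
///
/// A sketch of how the queuer is driven. The sending half of [`Self::rx`] is normally held by
/// the [`ShardManager`]; `shard_queue_tx` is an illustrative name for it:
///
/// ```rust,ignore
/// // Ask the queuer to boot shard 0 out of a total of 2 shards.
/// shard_queue_tx.unbounded_send(ShardQueuerMessage::Start(ShardId(0), ShardId(2)))?;
/// // Later, tell the queuer to shut all runners down and stop looping.
/// shard_queue_tx.unbounded_send(ShardQueuerMessage::Shutdown)?;
/// ```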
pub struct ShardQueuer {
    /// A copy of [`Client::data`] to be given to runners for contextual dispatching.
    ///
    /// [`Client::data`]: crate::Client::data
    pub data: Arc<RwLock<TypeMap>>,
    /// The [`EventHandler`]s, such as the ones given to the [`Client`].
    ///
    /// [`Client`]: crate::Client
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// The [`RawEventHandler`]s, such as the ones given to the [`Client`].
    ///
    /// [`Client`]: crate::Client
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// A copy of the framework.
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    /// The instant that a shard was last started.
    ///
    /// This is used to determine how long to wait between shard IDENTIFYs.
    pub last_start: Option<Instant>,
    /// A copy of the [`ShardManager`] to communicate with it.
    pub manager: Arc<ShardManager>,
    /// The shards that are queued for booting.
    ///
    /// This will typically be filled with previously failed boots.
    pub queue: VecDeque<ShardInfo>,
    /// A copy of the map of shard runners.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// A receiver channel for the shard queuer to be told to start shards.
    pub rx: Receiver<ShardQueuerMessage>,
    /// A copy of the client's voice manager.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// A copy of the URL to use to connect to the gateway.
    pub ws_url: Arc<Mutex<String>>,
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    pub http: Arc<Http>,
    pub intents: GatewayIntents,
    pub presence: Option<PresenceData>,
}

impl ShardQueuer {
    /// Begins the shard queuer loop.
    ///
    /// This will loop over the internal [`Self::rx`] for [`ShardQueuerMessage`]s, blocking on
    /// each read for a message telling it what to do.
    ///
    /// If a [`ShardQueuerMessage::Start`] is received, this will:
    ///
    /// 1. Check how much time has passed since the last shard was started
    /// 2. If the amount of time is less than the ratelimit, sleep until that time has passed
    /// 3. Start the shard by ID
    ///
    /// If a [`ShardQueuerMessage::Shutdown`] is received, this will return and the loop will be
    /// over.
    ///
    /// **Note**: This should be run in its own thread due to the blocking nature of the loop.
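    ///
    /// # Example
    ///
    /// A minimal sketch of driving the loop from a dedicated task; constructing the
    /// `shard_queuer` itself (normally done inside the [`Client`]) is assumed to have happened
    /// already:
    ///
    /// ```rust,ignore
    /// spawn_named("shard_queuer::run", async move {
    ///     shard_queuer.run().await;
    /// });
    /// ```
    ///
    /// [`Client`]: crate::Client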
    #[instrument(skip(self))]
    pub async fn run(&mut self) {
        // The duration after which reads over the Rx channel time out. Reads happen in a loop,
        // and if a read times out, the window is used to start a shard that is presently waiting
        // in the queue, if any.
        const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);

        loop {
            match timeout(TIMEOUT, self.rx.next()).await {
                Ok(Some(ShardQueuerMessage::Shutdown)) => {
                    debug!("[Shard Queuer] Received request to shut down.");
                    self.shutdown_runners().await;

                    break;
                },
                Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
                    debug!(
                        "[Shard Queuer] Received request to shut down shard {} with code {}.",
                        shard.0, code
                    );
                    self.shutdown(shard, code).await;
                },
                Ok(Some(ShardQueuerMessage::Start(id, total))) => {
                    debug!("[Shard Queuer] Received request to start shard {} of {}.", id.0, total.0);
                    self.checked_start(id, total.0).await;
                },
                // The sender side has been dropped, so there is nothing left to manage.
                Ok(None) => break,
                // The read timed out: use the window to boot a previously queued shard (e.g. one
                // re-queued after a failed start), if any is waiting.
                Err(_) => {
                    if let Some(shard) = self.queue.pop_front() {
                        self.checked_start(shard.id, shard.total).await;
                    }
                },
            }
        }
    }

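    /// Sleeps for whatever remains of the 5-second IDENTIFY ratelimit window.
    ///
    /// For example, if the last shard was started 2 seconds ago, this sleeps for the remaining 3
    /// seconds; if 5 or more seconds have already elapsed, or no shard has been started yet, it
    /// returns immediately.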
    #[instrument(skip(self))]
    async fn check_last_start(&mut self) {
        let Some(instant) = self.last_start else { return };

        // We must wait 5 seconds between IDENTIFYs to avoid session invalidations.
        let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
        let elapsed = instant.elapsed();

        if elapsed >= duration {
            return;
        }

        let to_sleep = duration - elapsed;

        sleep(to_sleep).await;
    }

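    /// Starts the shard with the given ID once [`Self::check_last_start`] has enforced the
    /// IDENTIFY ratelimit, pushing the boot back onto [`Self::queue`] if it fails.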
    #[instrument(skip(self))]
    async fn checked_start(&mut self, id: ShardId, total: u32) {
        debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
        self.check_last_start().await;

        if let Err(why) = self.start(id, total).await {
            warn!("[Shard Queuer] Error starting shard {}: {:?}", id, why);
            info!("[Shard Queuer] Re-queueing start of shard {}", id);

            self.queue.push_back(ShardInfo::new(id, total));
        }

        // Record the attempt regardless of outcome, so the ratelimit window still applies.
        self.last_start = Some(Instant::now());
    }

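    /// Builds a [`Shard`] and a [`ShardRunner`] to drive it, spawns the runner onto its own
    /// task, and records the runner's info in [`Self::runners`].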
    #[instrument(skip(self))]
    async fn start(&mut self, id: ShardId, total: u32) -> Result<()> {
        let shard_info = ShardInfo::new(id, total);

        let mut shard = Shard::new(
            Arc::clone(&self.ws_url),
            self.http.token(),
            shard_info,
            self.intents,
            self.presence.clone(),
        )
        .await?;

        let cloned_http = Arc::clone(&self.http);
        shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));

        let mut runner = ShardRunner::new(ShardRunnerOptions {
            data: Arc::clone(&self.data),
            event_handlers: self.event_handlers.clone(),
            raw_event_handlers: self.raw_event_handlers.clone(),
            #[cfg(feature = "framework")]
            framework: self.framework.get().cloned(),
            manager: Arc::clone(&self.manager),
            #[cfg(feature = "voice")]
            voice_manager: self.voice_manager.clone(),
            shard,
            #[cfg(feature = "cache")]
            cache: Arc::clone(&self.cache),
            http: Arc::clone(&self.http),
        });

        let runner_info = ShardRunnerInfo {
            latency: None,
            runner_tx: ShardMessenger::new(&runner),
            stage: ConnectionStage::Disconnected,
        };

        spawn_named("shard_queuer::stop", async move {
            drop(runner.run().await);
            debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
        });

        self.runners.lock().await.insert(id, runner_info);

        Ok(())
    }

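    /// Shuts down every shard runner currently registered in [`Self::runners`], using close code
    /// `1000` (normal closure).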
    #[instrument(skip(self))]
    async fn shutdown_runners(&mut self) {
        let keys = {
            let runners = self.runners.lock().await;

            if runners.is_empty() {
                return;
            }

            runners.keys().copied().collect::<Vec<_>>()
        };

        info!("Shutting down all shards");

        for shard_id in keys {
            self.shutdown(shard_id, 1000).await;
        }
    }

    /// Attempts to shut down the shard runner with the given ID.
    ///
    /// **Note**: If the receiving end of an mpsc channel (owned by the shard runner) no longer
    /// exists, then the shard runner will not know it should shut down. This _should never
    /// happen_; the runner may simply have already stopped.
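    ///
    /// # Example
    ///
    /// A sketch, assuming mutable access to the queuer (`1000` is the websocket close code for a
    /// normal closure):
    ///
    /// ```rust,ignore
    /// queuer.shutdown(ShardId(1), 1000).await;
    /// ```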
    #[instrument(skip(self))]
    pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
        info!("Shutting down shard {}", shard_id);

        if let Some(runner) = self.runners.lock().await.get(&shard_id) {
            let msg = ShardRunnerMessage::Shutdown(shard_id, code);

            if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) {
                warn!(
                    "Failed to cleanly shut down shard {} when sending message to shard runner: {:?}",
                    shard_id,
                    why,
                );
            }
        }
    }
}