Skip to main content

serenity/gateway/bridge/
shard_queuer.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3#[cfg(feature = "framework")]
4use std::sync::OnceLock;
5
6use futures::channel::mpsc::UnboundedReceiver as Receiver;
7use futures::StreamExt;
8use tokio::sync::{Mutex, RwLock};
9use tokio::time::{sleep, timeout, Duration, Instant};
10use tracing::{debug, info, instrument, warn};
11use typemap_rev::TypeMap;
12
13#[cfg(feature = "voice")]
14use super::VoiceGatewayManager;
15use super::{
16    ShardId,
17    ShardManager,
18    ShardMessenger,
19    ShardQueuerMessage,
20    ShardRunner,
21    ShardRunnerInfo,
22    ShardRunnerOptions,
23};
24#[cfg(feature = "cache")]
25use crate::cache::Cache;
26use crate::client::{EventHandler, RawEventHandler};
27#[cfg(feature = "framework")]
28use crate::framework::Framework;
29use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
30use crate::http::Http;
31use crate::internal::prelude::*;
32use crate::internal::tokio::spawn_named;
33use crate::model::gateway::{GatewayIntents, ShardInfo};
34
/// Number of seconds the queuer leaves between two consecutive shard start attempts.
///
/// The same value bounds how long the queuer waits for a message before it retries a shard from
/// its re-queue list.
const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;
36
37/// The shard queuer is a simple loop that runs indefinitely to manage the startup of shards.
38///
39/// A shard queuer instance _should_ be run in its own thread, due to the blocking nature of the
40/// loop itself as well as a 5 second thread sleep between shard starts.
41pub struct ShardQueuer {
42    /// A copy of [`Client::data`] to be given to runners for contextual dispatching.
43    ///
44    /// [`Client::data`]: crate::Client::data
45    pub data: Arc<RwLock<TypeMap>>,
46    /// A reference to an [`EventHandler`], such as the one given to the [`Client`].
47    ///
48    /// [`Client`]: crate::Client
49    pub event_handlers: Vec<Arc<dyn EventHandler>>,
50    /// A reference to an [`RawEventHandler`], such as the one given to the [`Client`].
51    ///
52    /// [`Client`]: crate::Client
53    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
54    /// A copy of the framework
55    #[cfg(feature = "framework")]
56    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
57    /// The instant that a shard was last started.
58    ///
59    /// This is used to determine how long to wait between shard IDENTIFYs.
60    pub last_start: Option<Instant>,
61    /// A copy of the [`ShardManager`] to communicate with it.
62    pub manager: Arc<ShardManager>,
63    /// The shards that are queued for booting.
64    ///
65    /// This will typically be filled with previously failed boots.
66    pub queue: VecDeque<ShardInfo>,
67    /// A copy of the map of shard runners.
68    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
69    /// A receiver channel for the shard queuer to be told to start shards.
70    pub rx: Receiver<ShardQueuerMessage>,
71    /// A copy of the client's voice manager.
72    #[cfg(feature = "voice")]
73    pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
74    /// A copy of the URL to use to connect to the gateway.
75    pub ws_url: Arc<Mutex<String>>,
76    #[cfg(feature = "cache")]
77    pub cache: Arc<Cache>,
78    pub http: Arc<Http>,
79    pub intents: GatewayIntents,
80    pub presence: Option<PresenceData>,
81}
82
83impl ShardQueuer {
84    /// Begins the shard queuer loop.
85    ///
86    /// This will loop over the internal [`Self::rx`] for [`ShardQueuerMessage`]s, blocking for
87    /// messages on what to do.
88    ///
89    /// If a [`ShardQueuerMessage::Start`] is received, this will:
90    ///
91    /// 1. Check how much time has passed since the last shard was started
92    /// 2. If the amount of time is less than the ratelimit, it will sleep until that time has
93    ///    passed
94    /// 3. Start the shard by ID
95    ///
96    /// If a [`ShardQueuerMessage::Shutdown`] is received, this will return and the loop will be
97    /// over.
98    ///
99    /// **Note**: This should be run in its own thread due to the blocking nature of the loop.
100    #[instrument(skip(self))]
101    pub async fn run(&mut self) {
102        // The duration to timeout from reads over the Rx channel. This can be done in a loop, and
103        // if the read times out then a shard can be started if one is presently waiting in the
104        // queue.
105        const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
106
107        loop {
108            match timeout(TIMEOUT, self.rx.next()).await {
109                Ok(Some(ShardQueuerMessage::Shutdown)) => {
110                    debug!("[Shard Queuer] Received to shutdown.");
111                    self.shutdown_runners().await;
112
113                    break;
114                },
115                Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
116                    debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
117                    self.shutdown(shard, code).await;
118                },
119                Ok(Some(ShardQueuerMessage::Start(id, total))) => {
120                    debug!("[Shard Queuer] Received to start shard {} of {}.", id.0, total.0);
121                    self.checked_start(id, total.0).await;
122                },
123                Ok(None) => break,
124                Err(_) => {
125                    if let Some(shard) = self.queue.pop_front() {
126                        self.checked_start(shard.id, shard.total).await;
127                    }
128                },
129            }
130        }
131    }
132
133    #[instrument(skip(self))]
134    async fn check_last_start(&mut self) {
135        let Some(instant) = self.last_start else { return };
136
137        // We must wait 5 seconds between IDENTIFYs to avoid session invalidations.
138        let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
139        let elapsed = instant.elapsed();
140
141        if let Some(to_sleep) = duration.checked_sub(elapsed) {
142            sleep(to_sleep).await;
143        }
144    }
145
146    #[instrument(skip(self))]
147    async fn checked_start(&mut self, id: ShardId, total: u32) {
148        debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
149        self.check_last_start().await;
150
151        if let Err(why) = self.start(id, total).await {
152            warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why);
153            info!("[Shard Queuer] Re-queueing start of shard {}", id);
154
155            self.queue.push_back(ShardInfo::new(id, total));
156        }
157
158        self.last_start = Some(Instant::now());
159    }
160
161    #[instrument(skip(self))]
162    async fn start(&mut self, id: ShardId, total: u32) -> Result<()> {
163        let shard_info = ShardInfo::new(id, total);
164
165        let mut shard = Shard::new(
166            Arc::clone(&self.ws_url),
167            self.http.token(),
168            shard_info,
169            self.intents,
170            self.presence.clone(),
171        )
172        .await?;
173
174        let cloned_http = Arc::clone(&self.http);
175        shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));
176
177        let mut runner = ShardRunner::new(ShardRunnerOptions {
178            data: Arc::clone(&self.data),
179            event_handlers: self.event_handlers.clone(),
180            raw_event_handlers: self.raw_event_handlers.clone(),
181            #[cfg(feature = "framework")]
182            framework: self.framework.get().cloned(),
183            manager: Arc::clone(&self.manager),
184            #[cfg(feature = "voice")]
185            voice_manager: self.voice_manager.clone(),
186            shard,
187            #[cfg(feature = "cache")]
188            cache: Arc::clone(&self.cache),
189            http: Arc::clone(&self.http),
190        });
191
192        let runner_info = ShardRunnerInfo {
193            latency: None,
194            runner_tx: ShardMessenger::new(&runner),
195            stage: ConnectionStage::Disconnected,
196        };
197
198        spawn_named("shard_queuer::stop", async move {
199            drop(Box::pin(runner.run()).await);
200            debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
201        });
202
203        self.runners.lock().await.insert(id, runner_info);
204
205        Ok(())
206    }
207
208    #[instrument(skip(self))]
209    async fn shutdown_runners(&mut self) {
210        let keys = {
211            let runners = self.runners.lock().await;
212
213            if runners.is_empty() {
214                return;
215            }
216
217            runners.keys().copied().collect::<Vec<_>>()
218        };
219
220        info!("Shutting down all shards");
221
222        for shard_id in keys {
223            self.shutdown(shard_id, 1000).await;
224        }
225    }
226
227    /// Attempts to shut down the shard runner by Id.
228    ///
229    /// **Note**: If the receiving end of an mpsc channel - owned by the shard runner - no longer
230    /// exists, then the shard runner will not know it should shut down. This _should never happen_.
231    /// It may already be stopped.
232    #[instrument(skip(self))]
233    pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
234        info!("Shutting down shard {}", shard_id);
235
236        if let Some(runner) = self.runners.lock().await.get(&shard_id) {
237            let msg = ShardRunnerMessage::Shutdown(shard_id, code);
238
239            if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) {
240                warn!(
241                    "Failed to cleanly shutdown shard {} when sending message to shard runner: {:?}",
242                    shard_id,
243                    why,
244                );
245            }
246        }
247    }
248}