Skip to main content

serenity/gateway/bridge/
shard_runner.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
5use tokio::sync::RwLock;
6use tokio_tungstenite::tungstenite;
7use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
8use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
9use tracing::{debug, error, info, instrument, trace, warn};
10use typemap_rev::TypeMap;
11
12use super::event::ShardStageUpdateEvent;
13#[cfg(feature = "collector")]
14use super::CollectorCallback;
15#[cfg(feature = "voice")]
16use super::VoiceGatewayManager;
17use super::{ShardId, ShardManager, ShardRunnerMessage};
18#[cfg(feature = "cache")]
19use crate::cache::Cache;
20use crate::client::dispatch::dispatch_model;
21use crate::client::{Context, EventHandler, RawEventHandler};
22#[cfg(feature = "framework")]
23use crate::framework::Framework;
24use crate::gateway::{GatewayError, ReconnectType, Shard, ShardAction};
25use crate::http::Http;
26use crate::internal::prelude::*;
27use crate::internal::tokio::spawn_named;
28use crate::model::event::{Event, GatewayEvent};
29
/// A runner for managing a [`Shard`] and its respective WebSocket client.
///
/// The runner owns the [`Shard`], drains [`ShardRunnerMessage`]s sent to it, and forwards
/// received gateway events to the registered handlers.
pub struct ShardRunner {
    /// Shared type map that is cloned (by `Arc`) into every [`Context`] this runner creates.
    data: Arc<RwLock<TypeMap>>,
    /// Handlers that receive dispatched events and shard stage updates.
    event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Handlers passed along to event dispatch together with `event_handlers`.
    raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework passed along to event dispatch.
    #[cfg(feature = "framework")]
    framework: Option<Arc<dyn Framework>>,
    /// Manager that is informed about latency/stage changes, restart requests and finished
    /// shutdowns of this runner's shard.
    manager: Arc<ShardManager>,
    /// Receiving half of the runner's message channel; drained at the start of every loop
    /// iteration.
    runner_rx: Receiver<ShardRunnerMessage>,
    /// Sending half of the runner's message channel; clones of it are handed out so that other
    /// components can send messages to this runner.
    runner_tx: Sender<ShardRunnerMessage>,
    /// The shard driven by this runner.
    pub(crate) shard: Shard,
    /// Optional voice manager that is notified about ready and voice-related events.
    #[cfg(feature = "voice")]
    voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// Cache handle that is cloned (by `Arc`) into every [`Context`] this runner creates.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client handle that is cloned (by `Arc`) into every [`Context`] this runner creates.
    pub http: Arc<Http>,
    /// Collector callbacks run against every dispatched event; a callback is dropped once it
    /// returns `false`.
    #[cfg(feature = "collector")]
    pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
}
51
52impl ShardRunner {
53    /// Creates a new runner for a Shard.
54    pub fn new(opt: ShardRunnerOptions) -> Self {
55        let (tx, rx) = mpsc::unbounded();
56
57        Self {
58            runner_rx: rx,
59            runner_tx: tx,
60            data: opt.data,
61            event_handlers: opt.event_handlers,
62            raw_event_handlers: opt.raw_event_handlers,
63            #[cfg(feature = "framework")]
64            framework: opt.framework,
65            manager: opt.manager,
66            shard: opt.shard,
67            #[cfg(feature = "voice")]
68            voice_manager: opt.voice_manager,
69            #[cfg(feature = "cache")]
70            cache: opt.cache,
71            http: opt.http,
72            #[cfg(feature = "collector")]
73            collectors: Arc::new(std::sync::Mutex::new(vec![])),
74        }
75    }
76
77    /// Starts the runner's loop to receive events.
78    ///
79    /// This runs a loop that performs the following in each iteration:
80    ///
81    /// 1. checks the receiver for [`ShardRunnerMessage`]s, possibly from the [`ShardManager`], and
82    ///    if there is one, acts on it.
83    ///
84    /// 2. checks if a heartbeat should be sent to the discord Gateway, and if so, sends one.
85    ///
86    /// 3. attempts to retrieve a message from the WebSocket, processing it into a [`GatewayEvent`].
87    ///    This will block for 100ms before assuming there is no message available.
88    ///
89    /// 4. Checks with the [`Shard`] to determine if the gateway event is specifying an action to
90    ///    take (e.g. resuming, reconnecting, heartbeating) and then performs that action, if any.
91    ///
92    /// 5. Dispatches the event via the Client.
93    ///
94    /// 6. Go back to 1.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if authentication fails or the intents contained disallowed values.
99    ///
100    /// [`ShardManager`]: super::ShardManager
101    #[instrument(skip(self))]
102    pub async fn run(&mut self) -> Result<()> {
103        info!("[ShardRunner {:?}] Running", self.shard.shard_info());
104
105        loop {
106            trace!("[ShardRunner {:?}] loop iteration started.", self.shard.shard_info());
107            if !self.recv().await {
108                return Ok(());
109            }
110
111            // check heartbeat
112            if !self.shard.do_heartbeat().await {
113                warn!("[ShardRunner {:?}] Error heartbeating", self.shard.shard_info(),);
114
115                self.request_restart().await;
116                return Ok(());
117            }
118
119            let pre = self.shard.stage();
120            let (event, action, successful) = self.recv_event().await?;
121            let post = self.shard.stage();
122
123            if post != pre {
124                self.update_manager().await;
125
126                for event_handler in self.event_handlers.clone() {
127                    let context = self.make_context();
128                    let event = ShardStageUpdateEvent {
129                        new: post,
130                        old: pre,
131                        shard_id: self.shard.shard_info().id,
132                    };
133                    spawn_named("dispatch::event_handler::shard_stage_update", async move {
134                        event_handler.shard_stage_update(context, event).await;
135                    });
136                }
137            }
138
139            match action {
140                Some(ShardAction::Reconnect(ReconnectType::Reidentify)) => {
141                    self.request_restart().await;
142                    return Ok(());
143                },
144                Some(other) => {
145                    if let Err(e) = self.action(&other).await {
146                        debug!(
147                            "[ShardRunner {:?}] Reconnecting due to error performing {:?}: {:?}",
148                            self.shard.shard_info(),
149                            other,
150                            e
151                        );
152                        match self.shard.reconnection_type() {
153                            ReconnectType::Reidentify => {
154                                self.request_restart().await;
155                                return Ok(());
156                            },
157                            ReconnectType::Resume => {
158                                if let Err(why) = self.shard.resume().await {
159                                    warn!(
160                                        "[ShardRunner {:?}] Resume failed, reidentifying: {:?}",
161                                        self.shard.shard_info(),
162                                        why
163                                    );
164
165                                    self.request_restart().await;
166                                    return Ok(());
167                                }
168                            },
169                        }
170                    }
171                },
172                None => {},
173            }
174
175            if let Some(event) = event {
176                #[cfg(feature = "collector")]
177                self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event));
178
179                dispatch_model(
180                    event,
181                    &self.make_context(),
182                    #[cfg(feature = "framework")]
183                    self.framework.clone(),
184                    self.event_handlers.clone(),
185                    self.raw_event_handlers.clone(),
186                );
187            }
188
189            if !successful && !self.shard.stage().is_connecting() {
190                self.request_restart().await;
191                return Ok(());
192            }
193            trace!("[ShardRunner {:?}] loop iteration reached the end.", self.shard.shard_info());
194        }
195    }
196
197    /// Clones the internal copy of the Sender to the shard runner.
198    pub(super) fn runner_tx(&self) -> Sender<ShardRunnerMessage> {
199        self.runner_tx.clone()
200    }
201
202    /// Takes an action that a [`Shard`] has determined should happen and then does it.
203    ///
204    /// For example, if the shard says that an Identify message needs to be sent, this will do
205    /// that.
206    ///
207    /// # Errors
208    ///
209    /// Returns
210    #[instrument(skip(self, action))]
211    async fn action(&mut self, action: &ShardAction) -> Result<()> {
212        match *action {
213            ShardAction::Reconnect(ReconnectType::Reidentify) => {
214                self.request_restart().await;
215                Ok(())
216            },
217            ShardAction::Reconnect(ReconnectType::Resume) => self.shard.resume().await,
218            ShardAction::Heartbeat => self.shard.heartbeat().await,
219            ShardAction::Identify => self.shard.identify().await,
220        }
221    }
222
223    // Checks if the ID received to shutdown is equivalent to the ID of the shard this runner is
224    // responsible. If so, it shuts down the WebSocket client.
225    //
226    // Returns whether the WebSocket client is still active.
227    //
228    // If true, the WebSocket client was _not_ shutdown. If false, it was.
229    #[instrument(skip(self))]
230    async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool {
231        // First verify the ID so we know for certain this runner is to shutdown.
232        if id != self.shard.shard_info().id {
233            // Not meant for this runner for some reason, don't shutdown.
234            return true;
235        }
236
237        // Send a Close Frame to Discord, which allows a bot to "log off"
238        drop(
239            self.shard
240                .client
241                .close(Some(CloseFrame {
242                    code: close_code.into(),
243                    reason: Cow::from(""),
244                }))
245                .await,
246        );
247
248        // In return, we wait for either a Close Frame response, or an error, after which this WS
249        // is deemed disconnected from Discord.
250        loop {
251            match self.shard.client.next().await {
252                Some(Ok(tungstenite::Message::Close(_))) => break,
253                Some(Err(_)) => {
254                    warn!(
255                        "[ShardRunner {:?}] Received an error awaiting close frame",
256                        self.shard.shard_info(),
257                    );
258                    break;
259                },
260                _ => {},
261            }
262        }
263
264        // Inform the manager that shutdown for this shard has finished.
265        self.manager.shutdown_finished(id);
266        false
267    }
268
269    fn make_context(&self) -> Context {
270        Context::new(
271            Arc::clone(&self.data),
272            self,
273            self.shard.shard_info().id,
274            Arc::clone(&self.http),
275            #[cfg(feature = "cache")]
276            Arc::clone(&self.cache),
277        )
278    }
279
280    // Handles a received value over the shard runner rx channel.
281    //
282    // Returns a boolean on whether the shard runner can continue.
283    //
284    // This always returns true, except in the case that the shard manager asked the runner to
285    // shutdown.
286    #[instrument(skip(self))]
287    async fn handle_rx_value(&mut self, msg: ShardRunnerMessage) -> bool {
288        match msg {
289            ShardRunnerMessage::Restart(id) => self.checked_shutdown(id, 4000).await,
290            ShardRunnerMessage::Shutdown(id, code) => self.checked_shutdown(id, code).await,
291            ShardRunnerMessage::ChunkGuild {
292                guild_id,
293                limit,
294                presences,
295                filter,
296                nonce,
297            } => self
298                .shard
299                .chunk_guild(guild_id, limit, presences, filter, nonce.as_deref())
300                .await
301                .is_ok(),
302            ShardRunnerMessage::SoundboardSounds {
303                guild_ids,
304            } => self.shard.request_soundboard_sounds(&guild_ids).await.is_ok(),
305            ShardRunnerMessage::Close(code, reason) => {
306                let reason = reason.unwrap_or_default();
307                let close = CloseFrame {
308                    code: code.into(),
309                    reason: Cow::from(reason),
310                };
311                self.shard.client.close(Some(close)).await.is_ok()
312            },
313            ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(),
314            ShardRunnerMessage::SetActivity(activity) => {
315                self.shard.set_activity(activity);
316                self.shard.update_presence().await.is_ok()
317            },
318            ShardRunnerMessage::SetPresence(activity, status) => {
319                self.shard.set_presence(activity, status);
320                self.shard.update_presence().await.is_ok()
321            },
322            ShardRunnerMessage::SetStatus(status) => {
323                self.shard.set_status(status);
324                self.shard.update_presence().await.is_ok()
325            },
326        }
327    }
328
329    #[cfg(feature = "voice")]
330    #[instrument(skip(self))]
331    async fn handle_voice_event(&self, event: &Event) {
332        if let Some(voice_manager) = &self.voice_manager {
333            match event {
334                Event::Ready(_) => {
335                    voice_manager
336                        .register_shard(self.shard.shard_info().id.0, self.runner_tx.clone())
337                        .await;
338                },
339                Event::VoiceServerUpdate(event) => {
340                    if let Some(guild_id) = event.guild_id {
341                        voice_manager.server_update(guild_id, &event.endpoint, &event.token).await;
342                    }
343                },
344                Event::VoiceStateUpdate(event) => {
345                    if let Some(guild_id) = event.voice_state.guild_id {
346                        voice_manager.state_update(guild_id, &event.voice_state).await;
347                    }
348                },
349                _ => {},
350            }
351        }
352    }
353
354    // Receives values over the internal shard runner rx channel and handles them.
355    //
356    // This will loop over values until there is no longer one.
357    //
358    // Requests a restart if the sending half of the channel disconnects. This should _never_
359    // happen, as the sending half is kept on the runner.
360    // Returns whether the shard runner is in a state that can continue.
361    #[instrument(skip(self))]
362    async fn recv(&mut self) -> bool {
363        loop {
364            match self.runner_rx.try_next() {
365                Ok(Some(value)) => {
366                    if !self.handle_rx_value(value).await {
367                        return false;
368                    }
369                },
370                Ok(None) => {
371                    warn!(
372                        "[ShardRunner {:?}] Sending half DC; restarting",
373                        self.shard.shard_info(),
374                    );
375
376                    self.request_restart().await;
377                    return false;
378                },
379                Err(_) => break,
380            }
381        }
382
383        // There are no longer any values available.
384
385        true
386    }
387
388    /// Returns a received event, as well as whether reading the potentially present event was
389    /// successful.
390    #[instrument(skip(self))]
391    async fn recv_event(&mut self) -> Result<(Option<Event>, Option<ShardAction>, bool)> {
392        let gw_event = match self.shard.client.recv_json().await {
393            Ok(inner) => Ok(inner),
394            Err(Error::Tungstenite(TungsteniteError::Io(_))) => {
395                debug!("Attempting to auto-reconnect");
396
397                match self.shard.reconnection_type() {
398                    ReconnectType::Reidentify => return Ok((None, None, false)),
399                    ReconnectType::Resume => {
400                        if let Err(why) = self.shard.resume().await {
401                            warn!("Failed to resume: {:?}", why);
402
403                            // Don't spam reattempts on internet connection loss
404                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
405
406                            return Ok((None, None, false));
407                        }
408                    },
409                }
410
411                return Ok((None, None, true));
412            },
413            Err(why) => Err(why),
414        };
415
416        let event = match gw_event {
417            Ok(Some(event)) => Ok(event),
418            Ok(None) => return Ok((None, None, true)),
419            Err(why) => Err(why),
420        };
421
422        let action = match self.shard.handle_event(&event) {
423            Ok(Some(action)) => Some(action),
424            Ok(None) => None,
425            Err(why) => {
426                error!("Shard handler received err: {:?}", why);
427
428                match &why {
429                    Error::Gateway(
430                        error @ (GatewayError::InvalidAuthentication
431                        | GatewayError::InvalidGatewayIntents
432                        | GatewayError::DisallowedGatewayIntents),
433                    ) => {
434                        self.manager.return_with_value(Err(error.clone())).await;
435
436                        return Err(why);
437                    },
438                    _ => return Ok((None, None, true)),
439                }
440            },
441        };
442
443        if let Ok(GatewayEvent::HeartbeatAck) = event {
444            self.update_manager().await;
445        }
446
447        #[cfg(feature = "voice")]
448        {
449            if let Ok(GatewayEvent::Dispatch(_, ref event)) = event {
450                self.handle_voice_event(event).await;
451            }
452        }
453
454        let event = match event {
455            Ok(GatewayEvent::Dispatch(_, event)) => Some(event),
456            _ => None,
457        };
458
459        Ok((event, action, true))
460    }
461
462    #[instrument(skip(self))]
463    async fn request_restart(&mut self) {
464        debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info());
465
466        self.update_manager().await;
467
468        let shard_id = self.shard.shard_info().id;
469        self.manager.restart_shard(shard_id).await;
470
471        #[cfg(feature = "voice")]
472        if let Some(voice_manager) = &self.voice_manager {
473            voice_manager.deregister_shard(shard_id.0).await;
474        }
475    }
476
477    #[instrument(skip(self))]
478    async fn update_manager(&self) {
479        self.manager
480            .update_shard_latency_and_stage(
481                self.shard.shard_info().id,
482                self.shard.latency(),
483                self.shard.stage(),
484            )
485            .await;
486    }
487}
488
/// Options to be passed to [`ShardRunner::new`].
///
/// Every field is moved into the created [`ShardRunner`] unchanged.
pub struct ShardRunnerOptions {
    /// Shared type map made available through the contexts the runner creates.
    pub data: Arc<RwLock<TypeMap>>,
    /// Event handlers the runner dispatches events to.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers the runner passes along when dispatching events.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework the runner passes along when dispatching events.
    #[cfg(feature = "framework")]
    pub framework: Option<Arc<dyn Framework>>,
    /// Manager the runner reports to.
    pub manager: Arc<ShardManager>,
    /// The shard the runner will drive.
    pub shard: Shard,
    /// Optional voice manager the runner notifies about voice-related events.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    /// Cache handle made available through the contexts the runner creates.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client handle made available through the contexts the runner creates.
    pub http: Arc<Http>,
}