serenity/gateway/bridge/
shard_runner.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
5use tokio::sync::RwLock;
6use tokio_tungstenite::tungstenite;
7use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
8use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
9use tracing::{debug, error, info, instrument, trace, warn};
10use typemap_rev::TypeMap;
11
12use super::event::ShardStageUpdateEvent;
13#[cfg(feature = "collector")]
14use super::CollectorCallback;
15#[cfg(feature = "voice")]
16use super::VoiceGatewayManager;
17use super::{ShardId, ShardManager, ShardRunnerMessage};
18#[cfg(feature = "cache")]
19use crate::cache::Cache;
20use crate::client::dispatch::dispatch_model;
21use crate::client::{Context, EventHandler, RawEventHandler};
22#[cfg(feature = "framework")]
23use crate::framework::Framework;
24use crate::gateway::{GatewayError, ReconnectType, Shard, ShardAction};
25use crate::http::Http;
26use crate::internal::prelude::*;
27use crate::internal::tokio::spawn_named;
28use crate::model::event::{Event, GatewayEvent};
29
/// A runner for managing a [`Shard`] and its respective WebSocket client.
///
/// The runner owns the shard, both halves of the channel carrying [`ShardRunnerMessage`]s, and
/// the handlers that received gateway events are dispatched to.
pub struct ShardRunner {
    // Shared type map; a clone of this handle is given to every `Context` the runner builds.
    data: Arc<RwLock<TypeMap>>,
    // Handlers notified of shard stage changes and passed along when dispatching events.
    event_handlers: Vec<Arc<dyn EventHandler>>,
    // Raw event handlers, passed along when dispatching events.
    raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    // Optional framework, passed along when dispatching events.
    #[cfg(feature = "framework")]
    framework: Option<Arc<dyn Framework>>,
    // The manager that is told about latency/stage updates, restart requests and finished
    // shutdowns of this shard.
    manager: Arc<ShardManager>,
    // Receiving half of the runner's message channel; drained at the start of each loop
    // iteration of `run`.
    runner_rx: Receiver<ShardRunnerMessage>,
    // Sending half of the runner's message channel; clones are handed out through `runner_tx()`
    // and to the voice manager.
    runner_tx: Sender<ShardRunnerMessage>,
    // The shard driven by this runner.
    pub(crate) shard: Shard,
    // Optional voice manager that is informed about ready and voice events of this shard.
    #[cfg(feature = "voice")]
    voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// The cache handle given to every `Context` built by this runner.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// The HTTP client handle given to every `Context` built by this runner.
    pub http: Arc<Http>,
    // Collector callbacks; each one is called with every dispatched event and is removed once it
    // returns `false`.
    #[cfg(feature = "collector")]
    pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
}
51
52impl ShardRunner {
53    /// Creates a new runner for a Shard.
54    pub fn new(opt: ShardRunnerOptions) -> Self {
55        let (tx, rx) = mpsc::unbounded();
56
57        Self {
58            runner_rx: rx,
59            runner_tx: tx,
60            data: opt.data,
61            event_handlers: opt.event_handlers,
62            raw_event_handlers: opt.raw_event_handlers,
63            #[cfg(feature = "framework")]
64            framework: opt.framework,
65            manager: opt.manager,
66            shard: opt.shard,
67            #[cfg(feature = "voice")]
68            voice_manager: opt.voice_manager,
69            #[cfg(feature = "cache")]
70            cache: opt.cache,
71            http: opt.http,
72            #[cfg(feature = "collector")]
73            collectors: Arc::new(std::sync::Mutex::new(vec![])),
74        }
75    }
76
77    /// Starts the runner's loop to receive events.
78    ///
79    /// This runs a loop that performs the following in each iteration:
80    ///
81    /// 1. checks the receiver for [`ShardRunnerMessage`]s, possibly from the [`ShardManager`], and
82    ///    if there is one, acts on it.
83    ///
84    /// 2. checks if a heartbeat should be sent to the discord Gateway, and if so, sends one.
85    ///
86    /// 3. attempts to retrieve a message from the WebSocket, processing it into a [`GatewayEvent`].
87    ///    This will block for 100ms before assuming there is no message available.
88    ///
89    /// 4. Checks with the [`Shard`] to determine if the gateway event is specifying an action to
90    ///    take (e.g. resuming, reconnecting, heartbeating) and then performs that action, if any.
91    ///
92    /// 5. Dispatches the event via the Client.
93    ///
94    /// 6. Go back to 1.
95    ///
96    /// [`ShardManager`]: super::ShardManager
97    #[instrument(skip(self))]
98    pub async fn run(&mut self) -> Result<()> {
99        info!("[ShardRunner {:?}] Running", self.shard.shard_info());
100
101        loop {
102            trace!("[ShardRunner {:?}] loop iteration started.", self.shard.shard_info());
103            if !self.recv().await {
104                return Ok(());
105            }
106
107            // check heartbeat
108            if !self.shard.do_heartbeat().await {
109                warn!("[ShardRunner {:?}] Error heartbeating", self.shard.shard_info(),);
110
111                self.request_restart().await;
112                return Ok(());
113            }
114
115            let pre = self.shard.stage();
116            let (event, action, successful) = self.recv_event().await?;
117            let post = self.shard.stage();
118
119            if post != pre {
120                self.update_manager().await;
121
122                for event_handler in self.event_handlers.clone() {
123                    let context = self.make_context();
124                    let event = ShardStageUpdateEvent {
125                        new: post,
126                        old: pre,
127                        shard_id: self.shard.shard_info().id,
128                    };
129                    spawn_named("dispatch::event_handler::shard_stage_update", async move {
130                        event_handler.shard_stage_update(context, event).await;
131                    });
132                }
133            }
134
135            match action {
136                Some(ShardAction::Reconnect(ReconnectType::Reidentify)) => {
137                    self.request_restart().await;
138                    return Ok(());
139                },
140                Some(other) => {
141                    if let Err(e) = self.action(&other).await {
142                        debug!(
143                            "[ShardRunner {:?}] Reconnecting due to error performing {:?}: {:?}",
144                            self.shard.shard_info(),
145                            other,
146                            e
147                        );
148                        match self.shard.reconnection_type() {
149                            ReconnectType::Reidentify => {
150                                self.request_restart().await;
151                                return Ok(());
152                            },
153                            ReconnectType::Resume => {
154                                if let Err(why) = self.shard.resume().await {
155                                    warn!(
156                                        "[ShardRunner {:?}] Resume failed, reidentifying: {:?}",
157                                        self.shard.shard_info(),
158                                        why
159                                    );
160
161                                    self.request_restart().await;
162                                    return Ok(());
163                                }
164                            },
165                        };
166                    }
167                },
168                None => {},
169            }
170
171            if let Some(event) = event {
172                #[cfg(feature = "collector")]
173                self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event));
174
175                dispatch_model(
176                    event,
177                    &self.make_context(),
178                    #[cfg(feature = "framework")]
179                    self.framework.clone(),
180                    self.event_handlers.clone(),
181                    self.raw_event_handlers.clone(),
182                );
183            }
184
185            if !successful && !self.shard.stage().is_connecting() {
186                self.request_restart().await;
187                return Ok(());
188            }
189            trace!("[ShardRunner {:?}] loop iteration reached the end.", self.shard.shard_info());
190        }
191    }
192
193    /// Clones the internal copy of the Sender to the shard runner.
194    pub(super) fn runner_tx(&self) -> Sender<ShardRunnerMessage> {
195        self.runner_tx.clone()
196    }
197
198    /// Takes an action that a [`Shard`] has determined should happen and then does it.
199    ///
200    /// For example, if the shard says that an Identify message needs to be sent, this will do
201    /// that.
202    ///
203    /// # Errors
204    ///
205    /// Returns
206    #[instrument(skip(self, action))]
207    async fn action(&mut self, action: &ShardAction) -> Result<()> {
208        match *action {
209            ShardAction::Reconnect(ReconnectType::Reidentify) => {
210                self.request_restart().await;
211                Ok(())
212            },
213            ShardAction::Reconnect(ReconnectType::Resume) => self.shard.resume().await,
214            ShardAction::Heartbeat => self.shard.heartbeat().await,
215            ShardAction::Identify => self.shard.identify().await,
216        }
217    }
218
219    // Checks if the ID received to shutdown is equivalent to the ID of the shard this runner is
220    // responsible. If so, it shuts down the WebSocket client.
221    //
222    // Returns whether the WebSocket client is still active.
223    //
224    // If true, the WebSocket client was _not_ shutdown. If false, it was.
225    #[instrument(skip(self))]
226    async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool {
227        // First verify the ID so we know for certain this runner is to shutdown.
228        if id != self.shard.shard_info().id {
229            // Not meant for this runner for some reason, don't shutdown.
230            return true;
231        }
232
233        // Send a Close Frame to Discord, which allows a bot to "log off"
234        drop(
235            self.shard
236                .client
237                .close(Some(CloseFrame {
238                    code: close_code.into(),
239                    reason: Cow::from(""),
240                }))
241                .await,
242        );
243
244        // In return, we wait for either a Close Frame response, or an error, after which this WS
245        // is deemed disconnected from Discord.
246        loop {
247            match self.shard.client.next().await {
248                Some(Ok(tungstenite::Message::Close(_))) => break,
249                Some(Err(_)) => {
250                    warn!(
251                        "[ShardRunner {:?}] Received an error awaiting close frame",
252                        self.shard.shard_info(),
253                    );
254                    break;
255                },
256                _ => continue,
257            }
258        }
259
260        // Inform the manager that shutdown for this shard has finished.
261        self.manager.shutdown_finished(id);
262        false
263    }
264
265    fn make_context(&self) -> Context {
266        Context::new(
267            Arc::clone(&self.data),
268            self,
269            self.shard.shard_info().id,
270            Arc::clone(&self.http),
271            #[cfg(feature = "cache")]
272            Arc::clone(&self.cache),
273        )
274    }
275
276    // Handles a received value over the shard runner rx channel.
277    //
278    // Returns a boolean on whether the shard runner can continue.
279    //
280    // This always returns true, except in the case that the shard manager asked the runner to
281    // shutdown.
282    #[instrument(skip(self))]
283    async fn handle_rx_value(&mut self, msg: ShardRunnerMessage) -> bool {
284        match msg {
285            ShardRunnerMessage::Restart(id) => self.checked_shutdown(id, 4000).await,
286            ShardRunnerMessage::Shutdown(id, code) => self.checked_shutdown(id, code).await,
287            ShardRunnerMessage::ChunkGuild {
288                guild_id,
289                limit,
290                presences,
291                filter,
292                nonce,
293            } => self
294                .shard
295                .chunk_guild(guild_id, limit, presences, filter, nonce.as_deref())
296                .await
297                .is_ok(),
298            ShardRunnerMessage::Close(code, reason) => {
299                let reason = reason.unwrap_or_default();
300                let close = CloseFrame {
301                    code: code.into(),
302                    reason: Cow::from(reason),
303                };
304                self.shard.client.close(Some(close)).await.is_ok()
305            },
306            ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(),
307            ShardRunnerMessage::SetActivity(activity) => {
308                self.shard.set_activity(activity);
309                self.shard.update_presence().await.is_ok()
310            },
311            ShardRunnerMessage::SetPresence(activity, status) => {
312                self.shard.set_presence(activity, status);
313                self.shard.update_presence().await.is_ok()
314            },
315            ShardRunnerMessage::SetStatus(status) => {
316                self.shard.set_status(status);
317                self.shard.update_presence().await.is_ok()
318            },
319        }
320    }
321
322    #[cfg(feature = "voice")]
323    #[instrument(skip(self))]
324    async fn handle_voice_event(&self, event: &Event) {
325        if let Some(voice_manager) = &self.voice_manager {
326            match event {
327                Event::Ready(_) => {
328                    voice_manager
329                        .register_shard(self.shard.shard_info().id.0, self.runner_tx.clone())
330                        .await;
331                },
332                Event::VoiceServerUpdate(event) => {
333                    if let Some(guild_id) = event.guild_id {
334                        voice_manager.server_update(guild_id, &event.endpoint, &event.token).await;
335                    }
336                },
337                Event::VoiceStateUpdate(event) => {
338                    if let Some(guild_id) = event.voice_state.guild_id {
339                        voice_manager.state_update(guild_id, &event.voice_state).await;
340                    }
341                },
342                _ => {},
343            }
344        }
345    }
346
347    // Receives values over the internal shard runner rx channel and handles them.
348    //
349    // This will loop over values until there is no longer one.
350    //
351    // Requests a restart if the sending half of the channel disconnects. This should _never_
352    // happen, as the sending half is kept on the runner.
353    // Returns whether the shard runner is in a state that can continue.
354    #[instrument(skip(self))]
355    async fn recv(&mut self) -> bool {
356        loop {
357            match self.runner_rx.try_next() {
358                Ok(Some(value)) => {
359                    if !self.handle_rx_value(value).await {
360                        return false;
361                    }
362                },
363                Ok(None) => {
364                    warn!(
365                        "[ShardRunner {:?}] Sending half DC; restarting",
366                        self.shard.shard_info(),
367                    );
368
369                    self.request_restart().await;
370                    return false;
371                },
372                Err(_) => break,
373            }
374        }
375
376        // There are no longer any values available.
377
378        true
379    }
380
381    /// Returns a received event, as well as whether reading the potentially present event was
382    /// successful.
383    #[instrument(skip(self))]
384    async fn recv_event(&mut self) -> Result<(Option<Event>, Option<ShardAction>, bool)> {
385        let gw_event = match self.shard.client.recv_json().await {
386            Ok(inner) => Ok(inner),
387            Err(Error::Tungstenite(TungsteniteError::Io(_))) => {
388                debug!("Attempting to auto-reconnect");
389
390                match self.shard.reconnection_type() {
391                    ReconnectType::Reidentify => return Ok((None, None, false)),
392                    ReconnectType::Resume => {
393                        if let Err(why) = self.shard.resume().await {
394                            warn!("Failed to resume: {:?}", why);
395
396                            // Don't spam reattempts on internet connection loss
397                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
398
399                            return Ok((None, None, false));
400                        }
401                    },
402                }
403
404                return Ok((None, None, true));
405            },
406            Err(why) => Err(why),
407        };
408
409        let event = match gw_event {
410            Ok(Some(event)) => Ok(event),
411            Ok(None) => return Ok((None, None, true)),
412            Err(why) => Err(why),
413        };
414
415        let action = match self.shard.handle_event(&event) {
416            Ok(Some(action)) => Some(action),
417            Ok(None) => None,
418            Err(why) => {
419                error!("Shard handler received err: {:?}", why);
420
421                match &why {
422                    Error::Gateway(
423                        error @ (GatewayError::InvalidAuthentication
424                        | GatewayError::InvalidGatewayIntents
425                        | GatewayError::DisallowedGatewayIntents),
426                    ) => {
427                        self.manager.return_with_value(Err(error.clone())).await;
428
429                        return Err(why);
430                    },
431                    _ => return Ok((None, None, true)),
432                }
433            },
434        };
435
436        if let Ok(GatewayEvent::HeartbeatAck) = event {
437            self.update_manager().await;
438        }
439
440        #[cfg(feature = "voice")]
441        {
442            if let Ok(GatewayEvent::Dispatch(_, ref event)) = event {
443                self.handle_voice_event(event).await;
444            }
445        }
446
447        let event = match event {
448            Ok(GatewayEvent::Dispatch(_, event)) => Some(event),
449            _ => None,
450        };
451
452        Ok((event, action, true))
453    }
454
455    #[instrument(skip(self))]
456    async fn request_restart(&mut self) {
457        debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info());
458
459        self.update_manager().await;
460
461        let shard_id = self.shard.shard_info().id;
462        self.manager.restart_shard(shard_id).await;
463
464        #[cfg(feature = "voice")]
465        if let Some(voice_manager) = &self.voice_manager {
466            voice_manager.deregister_shard(shard_id.0).await;
467        }
468    }
469
470    #[instrument(skip(self))]
471    async fn update_manager(&self) {
472        self.manager
473            .update_shard_latency_and_stage(
474                self.shard.shard_info().id,
475                self.shard.latency(),
476                self.shard.stage(),
477            )
478            .await;
479    }
480}
481
/// Options to be passed to [`ShardRunner::new`].
///
/// Every field is moved into the field of the same name on the created [`ShardRunner`].
pub struct ShardRunnerOptions {
    /// Shared type map that is made available through the contexts built by the runner.
    pub data: Arc<RwLock<TypeMap>>,
    /// Event handlers the runner notifies and dispatches to.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers the runner dispatches to.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework passed along when the runner dispatches events.
    #[cfg(feature = "framework")]
    pub framework: Option<Arc<dyn Framework>>,
    /// The manager the runner reports to and requests restarts from.
    pub manager: Arc<ShardManager>,
    /// The shard the runner will drive.
    pub shard: Shard,
    /// Optional voice manager the runner forwards ready and voice events to.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    /// The cache handle used for the contexts built by the runner.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// The HTTP client handle used for the contexts built by the runner.
    pub http: Arc<Http>,
}