serenity/gateway/shard.rs
1use std::sync::Arc;
2use std::time::{Duration as StdDuration, Instant};
3
4use tokio::sync::Mutex;
5use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
6use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
7use tracing::{debug, error, info, instrument, trace, warn};
8use url::Url;
9
10use super::{
11 ActivityData,
12 ChunkGuildFilter,
13 ConnectionStage,
14 GatewayError,
15 PresenceData,
16 ReconnectType,
17 ShardAction,
18 WsClient,
19};
20use crate::constants::{self, close_codes};
21use crate::internal::prelude::*;
22use crate::model::event::{Event, GatewayEvent};
23use crate::model::gateway::{GatewayIntents, ShardInfo};
24use crate::model::id::{ApplicationId, GuildId};
25use crate::model::user::OnlineStatus;
26
27/// A Shard is a higher-level handler for a websocket connection to Discord's gateway.
28///
29/// The shard allows for sending and receiving messages over the websocket, such as setting the
30/// active activity, reconnecting, syncing guilds, and more.
31///
32/// Refer to the [module-level documentation][module docs] for information on effectively using
33/// multiple shards, if you need to.
34///
35/// Note that there are additional methods available if you are manually managing a shard yourself,
36/// although they are hidden from the documentation since there are few use cases for doing such.
37///
38/// # Stand-alone shards
39///
40/// You may instantiate a shard yourself - decoupled from the [`Client`] - if you need to. For most
41/// use cases, you will not need to do this, and you can leave the client to do it.
42///
43/// This can be done by passing in the required parameters to [`Self::new`]. You can then manually
44/// handle the shard yourself.
45///
46/// **Note**: You _really_ do not need to do this. Just call one of the appropriate methods on the
47/// [`Client`].
48///
49/// # Examples
50///
51/// See the documentation for [`Self::new`] on how to use this.
52///
53/// [`Client`]: crate::Client
54/// [`receive`]: #method.receive
55/// [docs]: https://discord.com/developers/docs/topics/gateway#sharding
56/// [module docs]: crate::gateway#sharding
57pub struct Shard {
58 pub client: WsClient,
59 presence: PresenceData,
60 last_heartbeat_sent: Option<Instant>,
61 last_heartbeat_ack: Option<Instant>,
62 heartbeat_interval: Option<std::time::Duration>,
63 application_id_callback: Option<Box<dyn FnOnce(ApplicationId) + Send + Sync>>,
64 /// This is used by the heartbeater to determine whether the last heartbeat was sent without an
65 /// acknowledgement, and whether to reconnect.
66 // This must be set to `true` in `Shard::handle_event`'s `Ok(GatewayEvent::HeartbeatAck)` arm.
67 last_heartbeat_acknowledged: bool,
68 seq: u64,
69 session_id: Option<String>,
70 info: ShardInfo,
71 stage: ConnectionStage,
72 /// Instant of when the shard was started.
73 // This acts as a timeout to determine if the shard has - for some reason - not started within
74 // a decent amount of time.
75 pub started: Instant,
76 pub token: String,
77 ws_url: Arc<Mutex<String>>,
78 pub intents: GatewayIntents,
79}
80
81impl Shard {
82 /// Instantiates a new instance of a Shard, bypassing the client.
83 ///
84 /// **Note**: You should likely never need to do this yourself.
85 ///
86 /// # Examples
87 ///
88 /// Instantiating a new Shard manually for a bot with no shards, and then listening for events:
89 ///
90 /// ```rust,no_run
91 /// use std::sync::Arc;
92 ///
93 /// use serenity::gateway::Shard;
94 /// use serenity::model::gateway::{GatewayIntents, ShardInfo};
95 /// use serenity::model::id::ShardId;
96 /// use tokio::sync::Mutex;
97 /// #
98 /// # use serenity::http::Http;
99 /// #
100 /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
101 /// # let http: Arc<Http> = unimplemented!();
102 /// let token = std::env::var("DISCORD_BOT_TOKEN")?;
103 /// let shard_info = ShardInfo {
104 /// id: ShardId(0),
105 /// total: 1,
106 /// };
107 ///
108 /// // retrieve the gateway response, which contains the URL to connect to
109 /// let gateway = Arc::new(Mutex::new(http.get_gateway().await?.url));
110 /// let shard = Shard::new(gateway, &token, shard_info, GatewayIntents::all(), None).await?;
111 ///
112 /// // at this point, you can create a `loop`, and receive events and match
113 /// // their variants
114 /// # Ok(())
115 /// # }
116 /// ```
117 ///
118 /// # Errors
119 ///
120 /// On Error, will return either [`Error::Gateway`], [`Error::Tungstenite`] or a Rustls/native
121 /// TLS error.
122 pub async fn new(
123 ws_url: Arc<Mutex<String>>,
124 token: &str,
125 info: ShardInfo,
126 intents: GatewayIntents,
127 presence: Option<PresenceData>,
128 ) -> Result<Shard> {
129 let url = ws_url.lock().await.clone();
130 let client = connect(&url).await?;
131
132 let presence = presence.unwrap_or_default();
133 let last_heartbeat_sent = None;
134 let last_heartbeat_ack = None;
135 let heartbeat_interval = None;
136 let last_heartbeat_acknowledged = true;
137 let seq = 0;
138 let stage = ConnectionStage::Handshake;
139 let session_id = None;
140
141 Ok(Shard {
142 client,
143 presence,
144 last_heartbeat_sent,
145 last_heartbeat_ack,
146 heartbeat_interval,
147 application_id_callback: None,
148 last_heartbeat_acknowledged,
149 seq,
150 stage,
151 started: Instant::now(),
152 token: token.to_string(),
153 session_id,
154 info,
155 ws_url,
156 intents,
157 })
158 }
159
160 /// Sets a callback to be called when the gateway receives the application's ID from Discord.
161 ///
162 /// Used internally by serenity to set the Http's internal application ID automatically.
163 pub fn set_application_id_callback(
164 &mut self,
165 callback: impl FnOnce(ApplicationId) + Send + Sync + 'static,
166 ) {
167 self.application_id_callback = Some(Box::new(callback));
168 }
169
170 /// Retrieves the current presence of the shard.
171 #[inline]
172 pub fn presence(&self) -> &PresenceData {
173 &self.presence
174 }
175
176 /// Retrieves the value of when the last heartbeat was sent.
177 #[inline]
178 pub fn last_heartbeat_sent(&self) -> Option<Instant> {
179 self.last_heartbeat_sent
180 }
181
182 /// Retrieves the value of when the last heartbeat ack was received.
183 #[inline]
184 pub fn last_heartbeat_ack(&self) -> Option<Instant> {
185 self.last_heartbeat_ack
186 }
187
188 /// Sends a heartbeat to the gateway with the current sequence.
189 ///
190 /// This sets the last heartbeat time to now, and [`Self::last_heartbeat_acknowledged`] to
191 /// `false`.
192 ///
193 /// # Errors
194 ///
195 /// Returns [`GatewayError::HeartbeatFailed`] if there was an error sending a heartbeat.
196 #[instrument(skip(self))]
197 pub async fn heartbeat(&mut self) -> Result<()> {
198 match self.client.send_heartbeat(&self.info, Some(self.seq)).await {
199 Ok(()) => {
200 self.last_heartbeat_sent = Some(Instant::now());
201 self.last_heartbeat_acknowledged = false;
202
203 Ok(())
204 },
205 Err(why) => {
206 match why {
207 Error::Tungstenite(TungsteniteError::Io(err)) => {
208 if err.raw_os_error() != Some(32) {
209 debug!("[{:?}] Err heartbeating: {:?}", self.info, err);
210 }
211 },
212 other => {
213 warn!("[{:?}] Other err w/ keepalive: {:?}", self.info, other);
214 },
215 }
216
217 Err(Error::Gateway(GatewayError::HeartbeatFailed))
218 },
219 }
220 }
221
222 /// Returns the heartbeat interval dictated by Discord, if the Hello packet has been received.
223 #[inline]
224 pub fn heartbeat_interval(&self) -> Option<std::time::Duration> {
225 self.heartbeat_interval
226 }
227
228 #[inline]
229 pub fn last_heartbeat_acknowledged(&self) -> bool {
230 self.last_heartbeat_acknowledged
231 }
232
233 #[inline]
234 pub fn seq(&self) -> u64 {
235 self.seq
236 }
237
238 #[inline]
239 pub fn session_id(&self) -> Option<&String> {
240 self.session_id.as_ref()
241 }
242
243 #[inline]
244 #[instrument(skip(self))]
245 pub fn set_activity(&mut self, activity: Option<ActivityData>) {
246 self.presence.activity = activity;
247 }
248
249 #[inline]
250 #[instrument(skip(self))]
251 pub fn set_presence(&mut self, activity: Option<ActivityData>, status: OnlineStatus) {
252 self.set_activity(activity);
253 self.set_status(status);
254 }
255
256 #[inline]
257 #[instrument(skip(self))]
258 pub fn set_status(&mut self, mut status: OnlineStatus) {
259 if status == OnlineStatus::Offline {
260 status = OnlineStatus::Invisible;
261 }
262
263 self.presence.status = status;
264 }
265
266 /// Retrieves a copy of the current shard information.
267 ///
268 /// For example, if using 3 shards in total, and if this is shard 1, then it can be read as
269 /// "the second of three shards".
270 pub fn shard_info(&self) -> ShardInfo {
271 self.info
272 }
273
274 /// Returns the current connection stage of the shard.
275 pub fn stage(&self) -> ConnectionStage {
276 self.stage
277 }
278
279 #[instrument(skip(self))]
280 fn handle_gateway_dispatch(&mut self, seq: u64, event: &Event) -> Option<ShardAction> {
281 if seq > self.seq + 1 {
282 warn!("[{:?}] Sequence off; them: {}, us: {}", self.info, seq, self.seq);
283 }
284
285 match &event {
286 Event::Ready(ready) => {
287 debug!("[{:?}] Received Ready", self.info);
288
289 self.session_id = Some(ready.ready.session_id.clone());
290 self.stage = ConnectionStage::Connected;
291
292 if let Some(callback) = self.application_id_callback.take() {
293 callback(ready.ready.application.id);
294 }
295 },
296 Event::Resumed(_) => {
297 info!("[{:?}] Resumed", self.info);
298
299 self.stage = ConnectionStage::Connected;
300 self.last_heartbeat_acknowledged = true;
301 self.last_heartbeat_sent = Some(Instant::now());
302 self.last_heartbeat_ack = None;
303 },
304 _ => {},
305 }
306
307 self.seq = seq;
308
309 None
310 }
311
312 #[instrument(skip(self))]
313 fn handle_gateway_closed(
314 &mut self,
315 data: Option<&CloseFrame<'static>>,
316 ) -> Result<Option<ShardAction>> {
317 let num = data.map(|d| d.code.into());
318 let clean = num == Some(1000);
319
320 match num {
321 Some(close_codes::UNKNOWN_OPCODE) => {
322 warn!("[{:?}] Sent invalid opcode.", self.info);
323 },
324 Some(close_codes::DECODE_ERROR) => {
325 warn!("[{:?}] Sent invalid message.", self.info);
326 },
327 Some(close_codes::NOT_AUTHENTICATED) => {
328 warn!("[{:?}] Sent no authentication.", self.info);
329
330 return Err(Error::Gateway(GatewayError::NoAuthentication));
331 },
332 Some(close_codes::AUTHENTICATION_FAILED) => {
333 error!("[{:?}] Sent invalid authentication, please check the token.", self.info);
334
335 return Err(Error::Gateway(GatewayError::InvalidAuthentication));
336 },
337 Some(close_codes::ALREADY_AUTHENTICATED) => {
338 warn!("[{:?}] Already authenticated.", self.info);
339 },
340 Some(close_codes::INVALID_SEQUENCE) => {
341 warn!("[{:?}] Sent invalid seq: {}.", self.info, self.seq);
342
343 self.seq = 0;
344 },
345 Some(close_codes::RATE_LIMITED) => {
346 warn!("[{:?}] Gateway ratelimited.", self.info);
347 },
348 Some(close_codes::INVALID_SHARD) => {
349 warn!("[{:?}] Sent invalid shard data.", self.info);
350
351 return Err(Error::Gateway(GatewayError::InvalidShardData));
352 },
353 Some(close_codes::SHARDING_REQUIRED) => {
354 error!("[{:?}] Shard has too many guilds.", self.info);
355
356 return Err(Error::Gateway(GatewayError::OverloadedShard));
357 },
358 Some(4006 | close_codes::SESSION_TIMEOUT) => {
359 info!("[{:?}] Invalid session.", self.info);
360
361 self.session_id = None;
362 },
363 Some(close_codes::INVALID_GATEWAY_INTENTS) => {
364 error!("[{:?}] Invalid gateway intents have been provided.", self.info);
365
366 return Err(Error::Gateway(GatewayError::InvalidGatewayIntents));
367 },
368 Some(close_codes::DISALLOWED_GATEWAY_INTENTS) => {
369 error!("[{:?}] Disallowed gateway intents have been provided.", self.info);
370
371 return Err(Error::Gateway(GatewayError::DisallowedGatewayIntents));
372 },
373 Some(other) if !clean => {
374 warn!(
375 "[{:?}] Unknown unclean close {}: {:?}",
376 self.info,
377 other,
378 data.map(|d| &d.reason),
379 );
380 },
381 _ => {},
382 }
383
384 let resume = num
385 .map_or(true, |x| x != close_codes::AUTHENTICATION_FAILED && self.session_id.is_some());
386
387 Ok(Some(if resume {
388 ShardAction::Reconnect(ReconnectType::Resume)
389 } else {
390 ShardAction::Reconnect(ReconnectType::Reidentify)
391 }))
392 }
393
394 /// Handles an event from the gateway over the receiver, requiring the receiver to be passed if
395 /// a reconnect needs to occur.
396 ///
397 /// The best case scenario is that one of two values is returned:
398 /// - `Ok(None)`: a heartbeat, late hello, or session invalidation was received;
399 /// - `Ok(Some((event, None)))`: an op0 dispatch was received, and the shard's voice state will
400 /// be updated, _if_ the `voice` feature is enabled.
401 ///
402 /// # Errors
403 ///
404 /// Returns a [`GatewayError::InvalidAuthentication`] if invalid authentication was sent in the
405 /// IDENTIFY.
406 ///
407 /// Returns a [`GatewayError::InvalidShardData`] if invalid shard data was sent in the
408 /// IDENTIFY.
409 ///
410 /// Returns a [`GatewayError::NoAuthentication`] if no authentication was sent in the IDENTIFY.
411 ///
412 /// Returns a [`GatewayError::OverloadedShard`] if the shard would have too many guilds
413 /// assigned to it.
414 #[instrument(skip(self))]
415 pub fn handle_event(&mut self, event: &Result<GatewayEvent>) -> Result<Option<ShardAction>> {
416 match event {
417 Ok(GatewayEvent::Dispatch(seq, event)) => Ok(self.handle_gateway_dispatch(*seq, event)),
418 Ok(GatewayEvent::Heartbeat(..)) => {
419 info!("[{:?}] Received shard heartbeat", self.info);
420
421 Ok(Some(ShardAction::Heartbeat))
422 },
423 Ok(GatewayEvent::HeartbeatAck) => {
424 self.last_heartbeat_ack = Some(Instant::now());
425 self.last_heartbeat_acknowledged = true;
426
427 trace!("[{:?}] Received heartbeat ack", self.info);
428
429 Ok(None)
430 },
431 &Ok(GatewayEvent::Hello(interval)) => {
432 debug!("[{:?}] Received a Hello; interval: {}", self.info, interval);
433
434 if self.stage == ConnectionStage::Resuming {
435 return Ok(None);
436 }
437
438 self.heartbeat_interval = Some(std::time::Duration::from_millis(interval));
439
440 Ok(Some(if self.stage == ConnectionStage::Handshake {
441 ShardAction::Identify
442 } else {
443 debug!("[{:?}] Received late Hello; autoreconnecting", self.info);
444
445 ShardAction::Reconnect(self.reconnection_type())
446 }))
447 },
448 &Ok(GatewayEvent::InvalidateSession(resumable)) => {
449 info!("[{:?}] Received session invalidation", self.info);
450
451 Ok(Some(if resumable {
452 ShardAction::Reconnect(ReconnectType::Resume)
453 } else {
454 ShardAction::Reconnect(ReconnectType::Reidentify)
455 }))
456 },
457 Ok(GatewayEvent::Reconnect) => Ok(Some(ShardAction::Reconnect(ReconnectType::Resume))),
458 Err(Error::Gateway(GatewayError::Closed(data))) => {
459 self.handle_gateway_closed(data.as_ref())
460 },
461 Err(Error::Tungstenite(why)) => {
462 info!("[{:?}] Websocket error: {:?}", self.info, why);
463 info!("[{:?}] Will attempt to auto-reconnect", self.info);
464
465 Ok(Some(ShardAction::Reconnect(self.reconnection_type())))
466 },
467 Err(why) => {
468 warn!("[{:?}] Unhandled error: {:?}", self.info, why);
469
470 Ok(None)
471 },
472 }
473 }
474
475 /// Does a heartbeat if needed. Returns false if something went wrong and the shard should be
476 /// restarted.
477 ///
478 /// `true` is returned under one of the following conditions:
479 /// - the heartbeat interval has not elapsed
480 /// - a heartbeat was successfully sent
481 /// - there is no known heartbeat interval yet
482 ///
483 /// `false` is returned under one of the following conditions:
484 /// - a heartbeat acknowledgement was not received in time
485 /// - an error occurred while heartbeating
486 #[instrument(skip(self))]
487 pub async fn do_heartbeat(&mut self) -> bool {
488 let Some(heartbeat_interval) = self.heartbeat_interval else {
489 // No Hello received yet
490 return self.started.elapsed() < StdDuration::from_secs(15);
491 };
492
493 // If a duration of time less than the heartbeat_interval has passed, then don't perform a
494 // keepalive or attempt to reconnect.
495 if let Some(last_sent) = self.last_heartbeat_sent {
496 if last_sent.elapsed() <= heartbeat_interval {
497 return true;
498 }
499 }
500
501 // If the last heartbeat didn't receive an acknowledgement, then auto-reconnect.
502 if !self.last_heartbeat_acknowledged {
503 debug!("[{:?}] Last heartbeat not acknowledged", self.info,);
504
505 return false;
506 }
507
508 // Otherwise, we're good to heartbeat.
509 if let Err(why) = self.heartbeat().await {
510 warn!("[{:?}] Err heartbeating: {:?}", self.info, why);
511
512 false
513 } else {
514 trace!("[{:?}] Heartbeat", self.info);
515
516 true
517 }
518 }
519
520 /// Calculates the heartbeat latency between the shard and the gateway.
521 // Shamelessly stolen from brayzure's commit in eris:
522 // <https://github.com/abalabahaha/eris/commit/0ce296ae9a542bcec0edf1c999ee2d9986bed5a6>
523 #[instrument(skip(self))]
524 pub fn latency(&self) -> Option<StdDuration> {
525 if let (Some(sent), Some(received)) = (self.last_heartbeat_sent, self.last_heartbeat_ack) {
526 if received > sent {
527 return Some(received - sent);
528 }
529 }
530
531 None
532 }
533
534 /// Performs a deterministic reconnect.
535 ///
536 /// The type of reconnect is deterministic on whether a [`Self::session_id`].
537 ///
538 /// If the `session_id` still exists, then a RESUME is sent. If not, then an IDENTIFY is sent.
539 ///
540 /// Note that, if the shard is already in a stage of [`ConnectionStage::Connecting`], then no
541 /// action will be performed.
542 pub fn should_reconnect(&mut self) -> Option<ReconnectType> {
543 if self.stage == ConnectionStage::Connecting {
544 return None;
545 }
546
547 Some(self.reconnection_type())
548 }
549
550 pub fn reconnection_type(&self) -> ReconnectType {
551 if self.session_id().is_some() {
552 ReconnectType::Resume
553 } else {
554 ReconnectType::Reidentify
555 }
556 }
557
558 /// Requests that one or multiple [`Guild`]s be chunked.
559 ///
560 /// This will ask the gateway to start sending member chunks for large guilds (250 members+).
561 /// If a guild is over 250 members, then a full member list will not be downloaded, and must
562 /// instead be requested to be sent in "chunks" containing members.
563 ///
564 /// Member chunks are sent as the [`Event::GuildMembersChunk`] event. Each chunk only contains
565 /// a partial amount of the total members.
566 ///
567 /// If the `cache` feature is enabled, the cache will automatically be updated with member
568 /// chunks.
569 ///
570 /// # Examples
571 ///
572 /// Chunk a single guild by Id, limiting to 2000 [`Member`]s, and not
573 /// specifying a query parameter:
574 ///
575 /// ```rust,no_run
576 /// # use tokio::sync::Mutex;
577 /// # use serenity::gateway::{ChunkGuildFilter, Shard};
578 /// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
579 /// # use serenity::model::id::ShardId;
580 /// # use std::sync::Arc;
581 /// #
582 /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
583 /// # let mutex = Arc::new(Mutex::new("".to_string()));
584 /// # let shard_info = ShardInfo {
585 /// # id: ShardId(0),
586 /// # total: 1,
587 /// # };
588 /// #
589 /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?;
590 /// #
591 /// use serenity::model::id::GuildId;
592 ///
593 /// shard.chunk_guild(GuildId::new(81384788765712384), Some(2000), false, ChunkGuildFilter::None, None).await?;
594 /// # Ok(())
595 /// # }
596 /// ```
597 ///
598 /// Chunk a single guild by Id, limiting to 20 members, and specifying a query parameter of
599 /// `"do"` and a nonce of `"request"`:
600 ///
601 /// ```rust,no_run
602 /// # use tokio::sync::Mutex;
603 /// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
604 /// # use serenity::gateway::{ChunkGuildFilter, Shard};
605 /// # use serenity::model::id::ShardId;
606 /// # use std::error::Error;
607 /// # use std::sync::Arc;
608 /// #
609 /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
610 /// # let mutex = Arc::new(Mutex::new("".to_string()));
611 /// #
612 /// # let shard_info = ShardInfo {
613 /// # id: ShardId(0),
614 /// # total: 1,
615 /// # };
616 /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?;
617 /// #
618 /// use serenity::model::id::GuildId;
619 ///
620 /// shard
621 /// .chunk_guild(
622 /// GuildId::new(81384788765712384),
623 /// Some(20),
624 /// false,
625 /// ChunkGuildFilter::Query("do".to_owned()),
626 /// Some("request"),
627 /// )
628 /// .await?;
629 /// # Ok(())
630 /// # }
631 /// ```
632 ///
633 /// # Errors
634 /// Errors if there is a problem with the WS connection.
635 ///
636 /// [`Event::GuildMembersChunk`]: crate::model::event::Event::GuildMembersChunk
637 /// [`Guild`]: crate::model::guild::Guild
638 /// [`Member`]: crate::model::guild::Member
639 #[instrument(skip(self))]
640 pub async fn chunk_guild(
641 &mut self,
642 guild_id: GuildId,
643 limit: Option<u16>,
644 presences: bool,
645 filter: ChunkGuildFilter,
646 nonce: Option<&str>,
647 ) -> Result<()> {
648 debug!("[{:?}] Requesting member chunks", self.info);
649
650 self.client.send_chunk_guild(guild_id, &self.info, limit, presences, filter, nonce).await
651 }
652
653 /// Requests [Soundboard sounds][soundboard] to be fetched from one or multiple [`Guild`]s.
654 ///
655 /// This will ask the gateway to start sending soundboard sounds.
656 ///
657 /// Soundboard sounds are sent as the [`Event::SoundboardSounds`] event.
658 ///
659 /// # Errors
660 /// Errors if there is a problem with the WS connection.
661 ///
662 /// [`Event::SoundboardSounds`]: crate::model::event::Event::SoundboardSounds
663 /// [`Guild`]: crate::model::guild::Guild
664 /// [soundboard]: crate::model::soundboard::Soundboard
665 #[instrument(skip(self))]
666 pub async fn request_soundboard_sounds(&mut self, guild_ids: &[GuildId]) -> Result<()> {
667 debug!("[{:?}] Requesting soundboard sounds", self.info);
668
669 self.client.request_soundboard_sounds(guild_ids, &self.info).await
670 }
671
672 /// Sets the shard as going into identifying stage, which sets:
673 /// - the time that the last heartbeat sent as being now
674 /// - the `stage` to [`ConnectionStage::Identifying`]
675 ///
676 /// # Errors
677 /// Errors if there is a problem with the WS connection.
678 #[instrument(skip(self))]
679 pub async fn identify(&mut self) -> Result<()> {
680 self.client.send_identify(&self.info, &self.token, self.intents, &self.presence).await?;
681
682 self.last_heartbeat_sent = Some(Instant::now());
683 self.stage = ConnectionStage::Identifying;
684
685 Ok(())
686 }
687
688 /// Initializes a new WebSocket client.
689 ///
690 /// This will set the stage of the shard before and after instantiation of the client.
691 ///
692 /// # Errors
693 ///
694 /// Errors if unable to establish a websocket connection.
695 #[instrument(skip(self))]
696 pub async fn initialize(&mut self) -> Result<WsClient> {
697 debug!("[{:?}] Initializing.", self.info);
698
699 // We need to do two, sort of three things here:
700 // - set the stage of the shard as opening the websocket connection
701 // - open the websocket connection
702 // - if successful, set the current stage as Handshaking
703 //
704 // This is used to accurately assess whether the state of the shard is accurate when a
705 // Hello is received.
706 self.stage = ConnectionStage::Connecting;
707 self.started = Instant::now();
708 let url = &self.ws_url.lock().await.clone();
709 let client = connect(url).await?;
710 self.stage = ConnectionStage::Handshake;
711
712 Ok(client)
713 }
714
715 #[instrument(skip(self))]
716 pub async fn reset(&mut self) {
717 self.last_heartbeat_sent = Some(Instant::now());
718 self.last_heartbeat_ack = None;
719 self.heartbeat_interval = None;
720 self.last_heartbeat_acknowledged = true;
721 self.session_id = None;
722 self.stage = ConnectionStage::Disconnected;
723 self.seq = 0;
724 }
725
726 /// # Errors
727 ///
728 /// Errors if unable to re-establish a websocket connection.
729 #[instrument(skip(self))]
730 pub async fn resume(&mut self) -> Result<()> {
731 debug!("[{:?}] Attempting to resume", self.info);
732
733 self.client = self.initialize().await?;
734 self.stage = ConnectionStage::Resuming;
735
736 match &self.session_id {
737 Some(session_id) => {
738 self.client.send_resume(&self.info, session_id, self.seq, &self.token).await
739 },
740 None => Err(Error::Gateway(GatewayError::NoSessionId)),
741 }
742 }
743
744 /// # Errors
745 ///
746 /// Errors if unable to re-establish a websocket connection.
747 #[instrument(skip(self))]
748 pub async fn reconnect(&mut self) -> Result<()> {
749 info!("[{:?}] Attempting to reconnect", self.shard_info());
750
751 self.reset().await;
752 self.client = self.initialize().await?;
753
754 Ok(())
755 }
756
757 /// # Errors
758 ///
759 /// Errors if there is a problem with the WS connection.
760 #[instrument(skip(self))]
761 pub async fn update_presence(&mut self) -> Result<()> {
762 self.client.send_presence_update(&self.info, &self.presence).await
763 }
764}
765
766async fn connect(base_url: &str) -> Result<WsClient> {
767 let url =
768 Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
769 warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);
770
771 Error::Gateway(GatewayError::BuildingUrl)
772 })?;
773
774 WsClient::connect(url).await
775}