Skip to main content

serenity/gateway/
ws.rs

1use std::env::consts;
2#[cfg(feature = "client")]
3use std::io::Read;
4use std::time::SystemTime;
5
6#[cfg(feature = "client")]
7use flate2::read::ZlibDecoder;
8use futures::SinkExt;
9#[cfg(feature = "client")]
10use futures::StreamExt;
11use tokio::net::TcpStream;
12#[cfg(feature = "client")]
13use tokio::time::{timeout, Duration};
14#[cfg(feature = "client")]
15use tokio_tungstenite::tungstenite::protocol::CloseFrame;
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "client")]
18use tokio_tungstenite::tungstenite::Error as WsError;
19use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
21#[cfg(feature = "client")]
22use tracing::warn;
23use tracing::{debug, instrument, trace};
24use url::Url;
25
26use super::{ActivityData, ChunkGuildFilter, PresenceData};
27use crate::constants::{self, Opcode};
28#[cfg(feature = "client")]
29use crate::gateway::GatewayError;
30#[cfg(feature = "client")]
31use crate::json::from_str;
32use crate::json::to_string;
33#[cfg(feature = "client")]
34use crate::model::event::GatewayEvent;
35use crate::model::gateway::{GatewayIntents, ShardInfo};
36use crate::model::id::{GuildId, UserId};
37#[cfg(feature = "client")]
38use crate::Error;
39use crate::Result;
40
/// Client properties sent inside the `Identify` payload.
///
/// `WsClient::send_identify` fills `browser` and `device` with `"serenity"` and `os` with
/// `std::env::consts::OS`. Field names and order are the serialized JSON keys.
#[derive(Serialize)]
struct IdentifyProperties {
    browser: &'static str,
    device: &'static str,
    os: &'static str,
}
47
/// Payload of a `RequestGuildMembers` gateway message, built by `WsClient::send_chunk_guild`.
///
/// Field names and order are the serialized JSON keys.
#[derive(Serialize)]
struct ChunkGuildMessage<'a> {
    guild_id: GuildId,
    /// Member-name query. Omitted from the JSON entirely when `None` (i.e. when filtering by
    /// user IDs instead).
    #[serde(skip_serializing_if = "Option::is_none")]
    query: Option<&'a str>,
    /// Maximum number of members; `send_chunk_guild` sends `0` when the caller gives no limit.
    limit: u16,
    presences: bool,
    /// Explicit user IDs to fetch. Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    user_ids: Option<Vec<UserId>>,
    /// Caller-supplied nonce; `send_chunk_guild` sends `""` when none is given.
    nonce: &'a str,
}
59
/// Presence payload, sent standalone (`PresenceUpdate`) and embedded inside `Identify`.
///
/// Field names and order are the serialized JSON keys.
#[derive(Serialize)]
struct PresenceUpdateMessage<'a> {
    /// Always sent as `false` by the callers in this file.
    afk: bool,
    /// Status name, as produced by `presence.status.name()` at the call sites.
    status: &'a str,
    /// Set to `SystemTime::now()` by the callers in this file.
    ///
    /// NOTE(review): no custom serializer is attached, so this goes through serde's default
    /// `SystemTime` encoding (a struct of seconds/nanoseconds since the epoch). Discord documents
    /// `since` as integer Unix milliseconds or `null`. Confirm the gateway accepts this shape.
    since: SystemTime,
    activities: &'a [&'a ActivityData],
}
67
/// The `d` (data) part of an outgoing gateway message.
///
/// `#[serde(untagged)]` makes each variant serialize as its contents alone, with no variant name.
/// The kind of payload is conveyed by the `op` field of the enclosing `WebSocketMessage`.
#[derive(Serialize)]
#[serde(untagged)]
enum WebSocketMessageData<'a> {
    /// Sequence number passed to `send_heartbeat`; `None` serializes as JSON `null`.
    Heartbeat(Option<u64>),
    /// Member-chunk request, see `ChunkGuildMessage`.
    ChunkGuild(ChunkGuildMessage<'a>),
    /// Soundboard-sound request for the listed guilds.
    SoundboardSounds {
        guild_ids: &'a [GuildId],
    },
    /// Initial handshake payload built by `send_identify`.
    Identify {
        compress: bool,
        token: &'a str,
        large_threshold: u8,
        shard: &'a ShardInfo,
        intents: GatewayIntents,
        properties: IdentifyProperties,
        presence: PresenceUpdateMessage<'a>,
    },
    /// Standalone presence change built by `send_presence_update`.
    PresenceUpdate(PresenceUpdateMessage<'a>),
    /// Session resumption payload built by `send_resume`.
    Resume {
        session_id: &'a str,
        token: &'a str,
        seq: u64,
    },
}
92
/// Envelope for every outgoing gateway message: serialized as `{"op": <opcode>, "d": <data>}`.
#[derive(Serialize)]
struct WebSocketMessage<'a> {
    op: Opcode,
    d: WebSocketMessageData<'a>,
}
98
/// Gateway WebSocket connection: a thin wrapper over a tokio-tungstenite stream that may be
/// plain TCP or TLS (`MaybeTlsStream`).
pub struct WsClient(WebSocketStream<MaybeTlsStream<TcpStream>>);
100
/// How long `WsClient::recv_json` waits for a message before giving up and returning `Ok(None)`.
#[cfg(feature = "client")]
const TIMEOUT: Duration = Duration::from_millis(500);
/// Initial-capacity guess for inflating a compressed binary message: the output buffer is
/// pre-sized to this many times the compressed length (it can still grow beyond that).
#[cfg(feature = "client")]
const DECOMPRESSION_MULTIPLIER: usize = 3;
105
106impl WsClient {
107    pub(crate) async fn connect(url: Url) -> Result<Self> {
108        let config = WebSocketConfig {
109            max_message_size: None,
110            max_frame_size: None,
111            ..Default::default()
112        };
113        let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
114
115        Ok(Self(stream))
116    }
117
118    #[cfg(feature = "client")]
119    pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
120        let message = match timeout(TIMEOUT, self.0.next()).await {
121            Ok(Some(Ok(msg))) => msg,
122            Ok(Some(Err(e))) => return Err(e.into()),
123            Ok(None) | Err(_) => return Ok(None),
124        };
125
126        let value = match message {
127            Message::Binary(bytes) => {
128                let mut decompressed =
129                    String::with_capacity(bytes.len() * DECOMPRESSION_MULTIPLIER);
130
131                ZlibDecoder::new(&bytes[..]).read_to_string(&mut decompressed).map_err(|why| {
132                    warn!("Err decompressing bytes: {why:?}");
133                    debug!("Failing bytes: {bytes:?}");
134
135                    why
136                })?;
137
138                from_str(&decompressed).map_err(|why| {
139                    warn!("Err deserializing bytes: {why:?}");
140                    debug!("Failing bytes: {bytes:?}");
141
142                    why
143                })?
144            },
145            Message::Text(payload) => from_str(&payload).map_err(|why| {
146                warn!("Err deserializing text: {why:?}; text: {payload}");
147
148                why
149            })?,
150            Message::Close(Some(frame)) => {
151                return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
152            },
153            _ => return Ok(None),
154        };
155
156        Ok(Some(value))
157    }
158
159    pub(crate) async fn send_json(&mut self, value: &impl serde::Serialize) -> Result<()> {
160        let message = to_string(value).map(Message::Text)?;
161
162        self.0.send(message).await?;
163        Ok(())
164    }
165
166    /// Delegate to `StreamExt::next`
167    #[cfg(feature = "client")]
168    pub(crate) async fn next(&mut self) -> Option<std::result::Result<Message, WsError>> {
169        self.0.next().await
170    }
171
172    /// Delegate to `SinkExt::send`
173    #[cfg(feature = "client")]
174    pub(crate) async fn send(&mut self, message: Message) -> Result<()> {
175        self.0.send(message).await?;
176        Ok(())
177    }
178
179    /// Delegate to `WebSocketStream::close`
180    #[cfg(feature = "client")]
181    pub(crate) async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<()> {
182        self.0.close(msg).await?;
183        Ok(())
184    }
185
186    /// # Errors
187    ///
188    /// Errors if there is a problem with the WS connection.
189    pub async fn send_chunk_guild(
190        &mut self,
191        guild_id: GuildId,
192        shard_info: &ShardInfo,
193        limit: Option<u16>,
194        presences: bool,
195        filter: ChunkGuildFilter,
196        nonce: Option<&str>,
197    ) -> Result<()> {
198        debug!("[{:?}] Requesting member chunks", shard_info);
199
200        let (query, user_ids) = match filter {
201            ChunkGuildFilter::None => (Some(String::new()), None),
202            ChunkGuildFilter::Query(query) => (Some(query), None),
203            ChunkGuildFilter::UserIds(user_ids) => (None, Some(user_ids)),
204        };
205
206        self.send_json(&WebSocketMessage {
207            op: Opcode::RequestGuildMembers,
208            d: WebSocketMessageData::ChunkGuild(ChunkGuildMessage {
209                guild_id,
210                query: query.as_deref(),
211                limit: limit.unwrap_or(0),
212                presences,
213                user_ids,
214                nonce: nonce.unwrap_or(""),
215            }),
216        })
217        .await
218    }
219
220    /// # Errors
221    ///
222    /// Errors if there is a problem with the WS connection.
223    pub async fn request_soundboard_sounds(
224        &mut self,
225        guild_ids: &[GuildId],
226        shard_info: &ShardInfo,
227    ) -> Result<()> {
228        debug!("[{:?}] Requesting soundboard sounds", shard_info);
229
230        self.send_json(&WebSocketMessage {
231            op: Opcode::ReqeustSoundboardSounds,
232            d: WebSocketMessageData::SoundboardSounds {
233                guild_ids,
234            },
235        })
236        .await
237    }
238
239    /// # Errors
240    ///
241    /// Errors if there is a problem with the WS connection.
242    #[instrument(skip(self))]
243    pub async fn send_heartbeat(&mut self, shard_info: &ShardInfo, seq: Option<u64>) -> Result<()> {
244        trace!("[{:?}] Sending heartbeat d: {:?}", shard_info, seq);
245
246        self.send_json(&WebSocketMessage {
247            op: Opcode::Heartbeat,
248            d: WebSocketMessageData::Heartbeat(seq),
249        })
250        .await
251    }
252
253    /// # Errors
254    ///
255    /// Errors if there is a problem with the WS connection.
256    #[instrument(skip(self, token))]
257    pub async fn send_identify(
258        &mut self,
259        shard: &ShardInfo,
260        token: &str,
261        intents: GatewayIntents,
262        presence: &PresenceData,
263    ) -> Result<()> {
264        let activities: Vec<_> = presence.activity.iter().collect();
265        let now = SystemTime::now();
266
267        debug!("[{:?}] Identifying", shard);
268
269        let msg = WebSocketMessage {
270            op: Opcode::Identify,
271            d: WebSocketMessageData::Identify {
272                token,
273                shard,
274                intents,
275                compress: true,
276                large_threshold: constants::LARGE_THRESHOLD,
277                properties: IdentifyProperties {
278                    browser: "serenity",
279                    device: "serenity",
280                    os: consts::OS,
281                },
282                presence: PresenceUpdateMessage {
283                    afk: false,
284                    since: now,
285                    status: presence.status.name(),
286                    activities: &activities,
287                },
288            },
289        };
290
291        self.send_json(&msg).await
292    }
293
294    /// # Errors
295    ///
296    /// Errors if there is a problem with the WS connection.
297    #[instrument(skip(self))]
298    pub async fn send_presence_update(
299        &mut self,
300        shard_info: &ShardInfo,
301        presence: &PresenceData,
302    ) -> Result<()> {
303        let activities: Vec<_> = presence.activity.iter().collect();
304        let now = SystemTime::now();
305
306        debug!("[{:?}] Sending presence update", shard_info);
307
308        self.send_json(&WebSocketMessage {
309            op: Opcode::PresenceUpdate,
310            d: WebSocketMessageData::PresenceUpdate(PresenceUpdateMessage {
311                afk: false,
312                since: now,
313                status: presence.status.name(),
314                activities: &activities,
315            }),
316        })
317        .await
318    }
319
320    /// # Errors
321    ///
322    /// Errors if there is a problem with the WS connection.
323    #[instrument(skip(self, token))]
324    pub async fn send_resume(
325        &mut self,
326        shard_info: &ShardInfo,
327        session_id: &str,
328        seq: u64,
329        token: &str,
330    ) -> Result<()> {
331        debug!("[{:?}] Sending resume; seq: {}", shard_info, seq);
332
333        self.send_json(&WebSocketMessage {
334            op: Opcode::Resume,
335            d: WebSocketMessageData::Resume {
336                session_id,
337                token,
338                seq,
339            },
340        })
341        .await
342    }
343}