serenity/gateway/
ws.rs

1use std::env::consts;
2#[cfg(feature = "client")]
3use std::io::Read;
4use std::time::SystemTime;
5
6#[cfg(feature = "client")]
7use flate2::read::ZlibDecoder;
8use futures::SinkExt;
9#[cfg(feature = "client")]
10use futures::StreamExt;
11use tokio::net::TcpStream;
12#[cfg(feature = "client")]
13use tokio::time::{timeout, Duration};
14#[cfg(feature = "client")]
15use tokio_tungstenite::tungstenite::protocol::CloseFrame;
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "client")]
18use tokio_tungstenite::tungstenite::Error as WsError;
19use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
21#[cfg(feature = "client")]
22use tracing::warn;
23use tracing::{debug, instrument, trace};
24use url::Url;
25
26use super::{ActivityData, ChunkGuildFilter, PresenceData};
27use crate::constants::{self, Opcode};
28#[cfg(feature = "client")]
29use crate::gateway::GatewayError;
30#[cfg(feature = "client")]
31use crate::json::from_str;
32use crate::json::to_string;
33#[cfg(feature = "client")]
34use crate::model::event::GatewayEvent;
35use crate::model::gateway::{GatewayIntents, ShardInfo};
36use crate::model::id::{GuildId, UserId};
37#[cfg(feature = "client")]
38use crate::Error;
39use crate::Result;
40
41#[derive(Serialize)]
42struct IdentifyProperties {
43    browser: &'static str,
44    device: &'static str,
45    os: &'static str,
46}
47
48#[derive(Serialize)]
49struct ChunkGuildMessage<'a> {
50    guild_id: GuildId,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    query: Option<&'a str>,
53    limit: u16,
54    presences: bool,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    user_ids: Option<Vec<UserId>>,
57    nonce: &'a str,
58}
59
60#[derive(Serialize)]
61struct PresenceUpdateMessage<'a> {
62    afk: bool,
63    status: &'a str,
64    since: SystemTime,
65    activities: &'a [&'a ActivityData],
66}
67
68#[derive(Serialize)]
69#[serde(untagged)]
70enum WebSocketMessageData<'a> {
71    Heartbeat(Option<u64>),
72    ChunkGuild(ChunkGuildMessage<'a>),
73    Identify {
74        compress: bool,
75        token: &'a str,
76        large_threshold: u8,
77        shard: &'a ShardInfo,
78        intents: GatewayIntents,
79        properties: IdentifyProperties,
80        presence: PresenceUpdateMessage<'a>,
81    },
82    PresenceUpdate(PresenceUpdateMessage<'a>),
83    Resume {
84        session_id: &'a str,
85        token: &'a str,
86        seq: u64,
87    },
88}
89
90#[derive(Serialize)]
91struct WebSocketMessage<'a> {
92    op: Opcode,
93    d: WebSocketMessageData<'a>,
94}
95
96pub struct WsClient(WebSocketStream<MaybeTlsStream<TcpStream>>);
97
/// Longest time `WsClient::recv_json` waits for the next message before giving up.
#[cfg(feature = "client")]
const TIMEOUT: Duration = Duration::from_millis(500);
/// Factor applied to a compressed payload's length to pre-size the decompression buffer.
#[cfg(feature = "client")]
const DECOMPRESSION_MULTIPLIER: usize = 3;
102
103impl WsClient {
104    pub(crate) async fn connect(url: Url) -> Result<Self> {
105        let config = WebSocketConfig {
106            max_message_size: None,
107            max_frame_size: None,
108            ..Default::default()
109        };
110        let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
111
112        Ok(Self(stream))
113    }
114
115    #[cfg(feature = "client")]
116    pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
117        let message = match timeout(TIMEOUT, self.0.next()).await {
118            Ok(Some(Ok(msg))) => msg,
119            Ok(Some(Err(e))) => return Err(e.into()),
120            Ok(None) | Err(_) => return Ok(None),
121        };
122
123        let value = match message {
124            Message::Binary(bytes) => {
125                let mut decompressed =
126                    String::with_capacity(bytes.len() * DECOMPRESSION_MULTIPLIER);
127
128                ZlibDecoder::new(&bytes[..]).read_to_string(&mut decompressed).map_err(|why| {
129                    warn!("Err decompressing bytes: {why:?}");
130                    debug!("Failing bytes: {bytes:?}");
131
132                    why
133                })?;
134
135                from_str(&decompressed).map_err(|why| {
136                    warn!("Err deserializing bytes: {why:?}");
137                    debug!("Failing bytes: {bytes:?}");
138
139                    why
140                })?
141            },
142            Message::Text(payload) => from_str(&payload).map_err(|why| {
143                warn!("Err deserializing text: {why:?}; text: {payload}");
144
145                why
146            })?,
147            Message::Close(Some(frame)) => {
148                return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
149            },
150            _ => return Ok(None),
151        };
152
153        Ok(Some(value))
154    }
155
156    pub(crate) async fn send_json(&mut self, value: &impl serde::Serialize) -> Result<()> {
157        let message = to_string(value).map(Message::Text)?;
158
159        self.0.send(message).await?;
160        Ok(())
161    }
162
163    /// Delegate to `StreamExt::next`
164    #[cfg(feature = "client")]
165    pub(crate) async fn next(&mut self) -> Option<std::result::Result<Message, WsError>> {
166        self.0.next().await
167    }
168
169    /// Delegate to `SinkExt::send`
170    #[cfg(feature = "client")]
171    pub(crate) async fn send(&mut self, message: Message) -> Result<()> {
172        self.0.send(message).await?;
173        Ok(())
174    }
175
176    /// Delegate to `WebSocketStream::close`
177    #[cfg(feature = "client")]
178    pub(crate) async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<()> {
179        self.0.close(msg).await?;
180        Ok(())
181    }
182
183    #[allow(clippy::missing_errors_doc)]
184    pub async fn send_chunk_guild(
185        &mut self,
186        guild_id: GuildId,
187        shard_info: &ShardInfo,
188        limit: Option<u16>,
189        presences: bool,
190        filter: ChunkGuildFilter,
191        nonce: Option<&str>,
192    ) -> Result<()> {
193        debug!("[{:?}] Requesting member chunks", shard_info);
194
195        let (query, user_ids) = match filter {
196            ChunkGuildFilter::None => (Some(String::new()), None),
197            ChunkGuildFilter::Query(query) => (Some(query), None),
198            ChunkGuildFilter::UserIds(user_ids) => (None, Some(user_ids)),
199        };
200
201        self.send_json(&WebSocketMessage {
202            op: Opcode::RequestGuildMembers,
203            d: WebSocketMessageData::ChunkGuild(ChunkGuildMessage {
204                guild_id,
205                query: query.as_deref(),
206                limit: limit.unwrap_or(0),
207                presences,
208                user_ids,
209                nonce: nonce.unwrap_or(""),
210            }),
211        })
212        .await
213    }
214
215    #[instrument(skip(self))]
216    pub async fn send_heartbeat(&mut self, shard_info: &ShardInfo, seq: Option<u64>) -> Result<()> {
217        trace!("[{:?}] Sending heartbeat d: {:?}", shard_info, seq);
218
219        self.send_json(&WebSocketMessage {
220            op: Opcode::Heartbeat,
221            d: WebSocketMessageData::Heartbeat(seq),
222        })
223        .await
224    }
225
226    #[instrument(skip(self, token))]
227    pub async fn send_identify(
228        &mut self,
229        shard: &ShardInfo,
230        token: &str,
231        intents: GatewayIntents,
232        presence: &PresenceData,
233    ) -> Result<()> {
234        let activities: Vec<_> = presence.activity.iter().collect();
235        let now = SystemTime::now();
236
237        debug!("[{:?}] Identifying", shard);
238
239        let msg = WebSocketMessage {
240            op: Opcode::Identify,
241            d: WebSocketMessageData::Identify {
242                token,
243                shard,
244                intents,
245                compress: true,
246                large_threshold: constants::LARGE_THRESHOLD,
247                properties: IdentifyProperties {
248                    browser: "serenity",
249                    device: "serenity",
250                    os: consts::OS,
251                },
252                presence: PresenceUpdateMessage {
253                    afk: false,
254                    since: now,
255                    status: presence.status.name(),
256                    activities: &activities,
257                },
258            },
259        };
260
261        self.send_json(&msg).await
262    }
263
264    #[instrument(skip(self))]
265    pub async fn send_presence_update(
266        &mut self,
267        shard_info: &ShardInfo,
268        presence: &PresenceData,
269    ) -> Result<()> {
270        let activities: Vec<_> = presence.activity.iter().collect();
271        let now = SystemTime::now();
272
273        debug!("[{:?}] Sending presence update", shard_info);
274
275        self.send_json(&WebSocketMessage {
276            op: Opcode::PresenceUpdate,
277            d: WebSocketMessageData::PresenceUpdate(PresenceUpdateMessage {
278                afk: false,
279                since: now,
280                status: presence.status.name(),
281                activities: &activities,
282            }),
283        })
284        .await
285    }
286
287    #[instrument(skip(self, token))]
288    pub async fn send_resume(
289        &mut self,
290        shard_info: &ShardInfo,
291        session_id: &str,
292        seq: u64,
293        token: &str,
294    ) -> Result<()> {
295        debug!("[{:?}] Sending resume; seq: {}", shard_info, seq);
296
297        self.send_json(&WebSocketMessage {
298            op: Opcode::Resume,
299            d: WebSocketMessageData::Resume {
300                session_id,
301                token,
302                seq,
303            },
304        })
305        .await
306    }
307}