1use std::env::consts;
2#[cfg(feature = "client")]
3use std::io::Read;
4use std::time::SystemTime;
5
6#[cfg(feature = "client")]
7use flate2::read::ZlibDecoder;
8use futures::SinkExt;
9#[cfg(feature = "client")]
10use futures::StreamExt;
11use tokio::net::TcpStream;
12#[cfg(feature = "client")]
13use tokio::time::{timeout, Duration};
14#[cfg(feature = "client")]
15use tokio_tungstenite::tungstenite::protocol::CloseFrame;
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "client")]
18use tokio_tungstenite::tungstenite::Error as WsError;
19use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
21#[cfg(feature = "client")]
22use tracing::warn;
23use tracing::{debug, instrument, trace};
24use url::Url;
25
26use super::{ActivityData, ChunkGuildFilter, PresenceData};
27use crate::constants::{self, Opcode};
28#[cfg(feature = "client")]
29use crate::gateway::GatewayError;
30#[cfg(feature = "client")]
31use crate::json::from_str;
32use crate::json::to_string;
33#[cfg(feature = "client")]
34use crate::model::event::GatewayEvent;
35use crate::model::gateway::{GatewayIntents, ShardInfo};
36use crate::model::id::{GuildId, UserId};
37#[cfg(feature = "client")]
38use crate::Error;
39use crate::Result;
40
/// Client properties sent in the `properties` field of the Identify payload.
///
/// Serialized with serde in declaration order.
#[derive(Serialize)]
struct IdentifyProperties {
    /// Browser/library name; `send_identify` sets this to `"serenity"`.
    browser: &'static str,
    /// Device name; `send_identify` sets this to `"serenity"`.
    device: &'static str,
    /// Operating system name; `send_identify` fills it from `std::env::consts::OS`.
    os: &'static str,
}
47
/// Data payload of a `RequestGuildMembers` gateway message.
///
/// Serialized with serde in declaration order; the two `Option` fields are omitted
/// entirely from the JSON when `None`.
#[derive(Serialize)]
struct ChunkGuildMessage<'a> {
    /// Guild whose members are requested.
    guild_id: GuildId,
    /// Query-string filter; omitted when `None` (used when filtering by user IDs instead).
    #[serde(skip_serializing_if = "Option::is_none")]
    query: Option<&'a str>,
    /// Member limit; `send_chunk_guild` sends `0` when the caller gives no limit.
    limit: u16,
    /// Forwarded from the caller's `presences` flag.
    presences: bool,
    /// Specific user IDs to request; omitted when `None` (used when filtering by query).
    #[serde(skip_serializing_if = "Option::is_none")]
    user_ids: Option<Vec<UserId>>,
    /// Caller-supplied nonce; `send_chunk_guild` sends `""` when none is given.
    nonce: &'a str,
}
59
60#[derive(Serialize)]
61struct PresenceUpdateMessage<'a> {
62 afk: bool,
63 status: &'a str,
64 since: SystemTime,
65 activities: &'a [&'a ActivityData],
66}
67
/// The `d` (data) field of an outgoing gateway message.
///
/// `#[serde(untagged)]` makes each variant serialize as just its contents, with no variant
/// name, so the accompanying `op` in `WebSocketMessage` is what identifies the payload kind.
#[derive(Serialize)]
#[serde(untagged)]
enum WebSocketMessageData<'a> {
    /// Heartbeat data: the sequence number, serialized as a bare number or `null`.
    Heartbeat(Option<u64>),
    /// Request for guild member chunks.
    ChunkGuild(ChunkGuildMessage<'a>),
    /// Request for the soundboard sounds of the listed guilds.
    SoundboardSounds {
        guild_ids: &'a [GuildId],
    },
    /// Identify payload sent to start a new session.
    Identify {
        /// `send_identify` always sets this to `true`.
        compress: bool,
        token: &'a str,
        /// `send_identify` fills this from `constants::LARGE_THRESHOLD`.
        large_threshold: u8,
        shard: &'a ShardInfo,
        intents: GatewayIntents,
        properties: IdentifyProperties,
        presence: PresenceUpdateMessage<'a>,
    },
    /// Presence update for the current session.
    PresenceUpdate(PresenceUpdateMessage<'a>),
    /// Resume payload for an existing session.
    Resume {
        session_id: &'a str,
        token: &'a str,
        seq: u64,
    },
}
92
/// Envelope for every outgoing gateway message, serialized as `{"op": .., "d": ..}`.
#[derive(Serialize)]
struct WebSocketMessage<'a> {
    /// Gateway opcode identifying the kind of payload in `d`.
    op: Opcode,
    /// Opcode-specific payload.
    d: WebSocketMessageData<'a>,
}
98
/// Gateway WebSocket client: a thin wrapper around a tokio-tungstenite stream over a
/// (possibly TLS) TCP connection.
pub struct WsClient(WebSocketStream<MaybeTlsStream<TcpStream>>);
100
/// How long `WsClient::recv_json` waits for the next message before returning `Ok(None)`.
#[cfg(feature = "client")]
const TIMEOUT: Duration = Duration::from_millis(500);
/// Initial capacity factor for the decompression buffer in `WsClient::recv_json`:
/// the buffer starts at `compressed_len * DECOMPRESSION_MULTIPLIER` bytes. This is only
/// a size hint; the `String` still grows if the decompressed payload is larger.
#[cfg(feature = "client")]
const DECOMPRESSION_MULTIPLIER: usize = 3;
105
106impl WsClient {
107 pub(crate) async fn connect(url: Url) -> Result<Self> {
108 let config = WebSocketConfig {
109 max_message_size: None,
110 max_frame_size: None,
111 ..Default::default()
112 };
113 let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
114
115 Ok(Self(stream))
116 }
117
118 #[cfg(feature = "client")]
119 pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
120 let message = match timeout(TIMEOUT, self.0.next()).await {
121 Ok(Some(Ok(msg))) => msg,
122 Ok(Some(Err(e))) => return Err(e.into()),
123 Ok(None) | Err(_) => return Ok(None),
124 };
125
126 let value = match message {
127 Message::Binary(bytes) => {
128 let mut decompressed =
129 String::with_capacity(bytes.len() * DECOMPRESSION_MULTIPLIER);
130
131 ZlibDecoder::new(&bytes[..]).read_to_string(&mut decompressed).map_err(|why| {
132 warn!("Err decompressing bytes: {why:?}");
133 debug!("Failing bytes: {bytes:?}");
134
135 why
136 })?;
137
138 from_str(&decompressed).map_err(|why| {
139 warn!("Err deserializing bytes: {why:?}");
140 debug!("Failing bytes: {bytes:?}");
141
142 why
143 })?
144 },
145 Message::Text(payload) => from_str(&payload).map_err(|why| {
146 warn!("Err deserializing text: {why:?}; text: {payload}");
147
148 why
149 })?,
150 Message::Close(Some(frame)) => {
151 return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
152 },
153 _ => return Ok(None),
154 };
155
156 Ok(Some(value))
157 }
158
159 pub(crate) async fn send_json(&mut self, value: &impl serde::Serialize) -> Result<()> {
160 let message = to_string(value).map(Message::Text)?;
161
162 self.0.send(message).await?;
163 Ok(())
164 }
165
166 #[cfg(feature = "client")]
168 pub(crate) async fn next(&mut self) -> Option<std::result::Result<Message, WsError>> {
169 self.0.next().await
170 }
171
172 #[cfg(feature = "client")]
174 pub(crate) async fn send(&mut self, message: Message) -> Result<()> {
175 self.0.send(message).await?;
176 Ok(())
177 }
178
179 #[cfg(feature = "client")]
181 pub(crate) async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<()> {
182 self.0.close(msg).await?;
183 Ok(())
184 }
185
186 pub async fn send_chunk_guild(
190 &mut self,
191 guild_id: GuildId,
192 shard_info: &ShardInfo,
193 limit: Option<u16>,
194 presences: bool,
195 filter: ChunkGuildFilter,
196 nonce: Option<&str>,
197 ) -> Result<()> {
198 debug!("[{:?}] Requesting member chunks", shard_info);
199
200 let (query, user_ids) = match filter {
201 ChunkGuildFilter::None => (Some(String::new()), None),
202 ChunkGuildFilter::Query(query) => (Some(query), None),
203 ChunkGuildFilter::UserIds(user_ids) => (None, Some(user_ids)),
204 };
205
206 self.send_json(&WebSocketMessage {
207 op: Opcode::RequestGuildMembers,
208 d: WebSocketMessageData::ChunkGuild(ChunkGuildMessage {
209 guild_id,
210 query: query.as_deref(),
211 limit: limit.unwrap_or(0),
212 presences,
213 user_ids,
214 nonce: nonce.unwrap_or(""),
215 }),
216 })
217 .await
218 }
219
220 pub async fn request_soundboard_sounds(
224 &mut self,
225 guild_ids: &[GuildId],
226 shard_info: &ShardInfo,
227 ) -> Result<()> {
228 debug!("[{:?}] Requesting soundboard sounds", shard_info);
229
230 self.send_json(&WebSocketMessage {
231 op: Opcode::ReqeustSoundboardSounds,
232 d: WebSocketMessageData::SoundboardSounds {
233 guild_ids,
234 },
235 })
236 .await
237 }
238
239 #[instrument(skip(self))]
243 pub async fn send_heartbeat(&mut self, shard_info: &ShardInfo, seq: Option<u64>) -> Result<()> {
244 trace!("[{:?}] Sending heartbeat d: {:?}", shard_info, seq);
245
246 self.send_json(&WebSocketMessage {
247 op: Opcode::Heartbeat,
248 d: WebSocketMessageData::Heartbeat(seq),
249 })
250 .await
251 }
252
253 #[instrument(skip(self, token))]
257 pub async fn send_identify(
258 &mut self,
259 shard: &ShardInfo,
260 token: &str,
261 intents: GatewayIntents,
262 presence: &PresenceData,
263 ) -> Result<()> {
264 let activities: Vec<_> = presence.activity.iter().collect();
265 let now = SystemTime::now();
266
267 debug!("[{:?}] Identifying", shard);
268
269 let msg = WebSocketMessage {
270 op: Opcode::Identify,
271 d: WebSocketMessageData::Identify {
272 token,
273 shard,
274 intents,
275 compress: true,
276 large_threshold: constants::LARGE_THRESHOLD,
277 properties: IdentifyProperties {
278 browser: "serenity",
279 device: "serenity",
280 os: consts::OS,
281 },
282 presence: PresenceUpdateMessage {
283 afk: false,
284 since: now,
285 status: presence.status.name(),
286 activities: &activities,
287 },
288 },
289 };
290
291 self.send_json(&msg).await
292 }
293
294 #[instrument(skip(self))]
298 pub async fn send_presence_update(
299 &mut self,
300 shard_info: &ShardInfo,
301 presence: &PresenceData,
302 ) -> Result<()> {
303 let activities: Vec<_> = presence.activity.iter().collect();
304 let now = SystemTime::now();
305
306 debug!("[{:?}] Sending presence update", shard_info);
307
308 self.send_json(&WebSocketMessage {
309 op: Opcode::PresenceUpdate,
310 d: WebSocketMessageData::PresenceUpdate(PresenceUpdateMessage {
311 afk: false,
312 since: now,
313 status: presence.status.name(),
314 activities: &activities,
315 }),
316 })
317 .await
318 }
319
320 #[instrument(skip(self, token))]
324 pub async fn send_resume(
325 &mut self,
326 shard_info: &ShardInfo,
327 session_id: &str,
328 seq: u64,
329 token: &str,
330 ) -> Result<()> {
331 debug!("[{:?}] Sending resume; seq: {}", shard_info, seq);
332
333 self.send_json(&WebSocketMessage {
334 op: Opcode::Resume,
335 d: WebSocketMessageData::Resume {
336 session_id,
337 token,
338 seq,
339 },
340 })
341 .await
342 }
343}