1use std::env::consts;
2#[cfg(feature = "client")]
3use std::io::Read;
4use std::time::SystemTime;
5
6#[cfg(feature = "client")]
7use flate2::read::ZlibDecoder;
8use futures::SinkExt;
9#[cfg(feature = "client")]
10use futures::StreamExt;
11use tokio::net::TcpStream;
12#[cfg(feature = "client")]
13use tokio::time::{timeout, Duration};
14#[cfg(feature = "client")]
15use tokio_tungstenite::tungstenite::protocol::CloseFrame;
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "client")]
18use tokio_tungstenite::tungstenite::Error as WsError;
19use tokio_tungstenite::tungstenite::Message;
20use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
21#[cfg(feature = "client")]
22use tracing::warn;
23use tracing::{debug, instrument, trace};
24use url::Url;
25
26use super::{ActivityData, ChunkGuildFilter, PresenceData};
27use crate::constants::{self, Opcode};
28#[cfg(feature = "client")]
29use crate::gateway::GatewayError;
30#[cfg(feature = "client")]
31use crate::json::from_str;
32use crate::json::to_string;
33#[cfg(feature = "client")]
34use crate::model::event::GatewayEvent;
35use crate::model::gateway::{GatewayIntents, ShardInfo};
36use crate::model::id::{GuildId, UserId};
37#[cfg(feature = "client")]
38use crate::Error;
39use crate::Result;
40
/// Client properties sent in the `properties` field of an Identify payload.
///
/// Field declaration order is the JSON key order produced by `#[derive(Serialize)]`.
#[derive(Serialize)]
struct IdentifyProperties {
    browser: &'static str,
    device: &'static str,
    os: &'static str,
}
47
/// Payload (`d`) of a Request Guild Members message (sent with `Opcode::RequestGuildMembers`).
///
/// `WsClient::send_chunk_guild` populates exactly one of `query` / `user_ids`; the one left as
/// `None` is omitted from the serialized JSON.
#[derive(Serialize)]
struct ChunkGuildMessage<'a> {
    guild_id: GuildId,
    #[serde(skip_serializing_if = "Option::is_none")]
    query: Option<&'a str>,
    // `send_chunk_guild` sends 0 when the caller gave no limit.
    limit: u16,
    presences: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_ids: Option<Vec<UserId>>,
    // `send_chunk_guild` sends an empty string when the caller gave no nonce.
    nonce: &'a str,
}
59
60#[derive(Serialize)]
61struct PresenceUpdateMessage<'a> {
62 afk: bool,
63 status: &'a str,
64 since: SystemTime,
65 activities: &'a [&'a ActivityData],
66}
67
/// The `d` (data) field of an outgoing gateway message.
///
/// Because of `#[serde(untagged)]`, each variant serializes as just its contents, without the
/// variant name. The kind of message is carried by the sibling `op` field of `WebSocketMessage`.
/// Within struct-like variants, field declaration order is the JSON key order.
#[derive(Serialize)]
#[serde(untagged)]
enum WebSocketMessageData<'a> {
    /// The `seq` passed to `WsClient::send_heartbeat`; serialized as a number, or `null` for `None`.
    Heartbeat(Option<u64>),
    ChunkGuild(ChunkGuildMessage<'a>),
    Identify {
        compress: bool,
        token: &'a str,
        large_threshold: u8,
        shard: &'a ShardInfo,
        intents: GatewayIntents,
        properties: IdentifyProperties,
        presence: PresenceUpdateMessage<'a>,
    },
    PresenceUpdate(PresenceUpdateMessage<'a>),
    Resume {
        session_id: &'a str,
        token: &'a str,
        seq: u64,
    },
}
89
/// Envelope for every outgoing gateway message: the opcode `op` plus its payload `d`.
#[derive(Serialize)]
struct WebSocketMessage<'a> {
    op: Opcode,
    d: WebSocketMessageData<'a>,
}
95
/// A gateway WebSocket connection: a thin wrapper around a tokio-tungstenite stream over a TCP
/// socket that may or may not be TLS-wrapped (`MaybeTlsStream`).
pub struct WsClient(WebSocketStream<MaybeTlsStream<TcpStream>>);
97
/// How long `WsClient::recv_json` waits for a message before giving up and returning `Ok(None)`.
#[cfg(feature = "client")]
const TIMEOUT: Duration = Duration::from_millis(500);
/// Initial capacity of the decompression buffer, as a multiple of the compressed length.
/// This is only a sizing hint; the `String` still grows if the output is larger.
#[cfg(feature = "client")]
const DECOMPRESSION_MULTIPLIER: usize = 3;
102
103impl WsClient {
104 pub(crate) async fn connect(url: Url) -> Result<Self> {
105 let config = WebSocketConfig {
106 max_message_size: None,
107 max_frame_size: None,
108 ..Default::default()
109 };
110 let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
111
112 Ok(Self(stream))
113 }
114
115 #[cfg(feature = "client")]
116 pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
117 let message = match timeout(TIMEOUT, self.0.next()).await {
118 Ok(Some(Ok(msg))) => msg,
119 Ok(Some(Err(e))) => return Err(e.into()),
120 Ok(None) | Err(_) => return Ok(None),
121 };
122
123 let value = match message {
124 Message::Binary(bytes) => {
125 let mut decompressed =
126 String::with_capacity(bytes.len() * DECOMPRESSION_MULTIPLIER);
127
128 ZlibDecoder::new(&bytes[..]).read_to_string(&mut decompressed).map_err(|why| {
129 warn!("Err decompressing bytes: {why:?}");
130 debug!("Failing bytes: {bytes:?}");
131
132 why
133 })?;
134
135 from_str(&decompressed).map_err(|why| {
136 warn!("Err deserializing bytes: {why:?}");
137 debug!("Failing bytes: {bytes:?}");
138
139 why
140 })?
141 },
142 Message::Text(payload) => from_str(&payload).map_err(|why| {
143 warn!("Err deserializing text: {why:?}; text: {payload}");
144
145 why
146 })?,
147 Message::Close(Some(frame)) => {
148 return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
149 },
150 _ => return Ok(None),
151 };
152
153 Ok(Some(value))
154 }
155
156 pub(crate) async fn send_json(&mut self, value: &impl serde::Serialize) -> Result<()> {
157 let message = to_string(value).map(Message::Text)?;
158
159 self.0.send(message).await?;
160 Ok(())
161 }
162
163 #[cfg(feature = "client")]
165 pub(crate) async fn next(&mut self) -> Option<std::result::Result<Message, WsError>> {
166 self.0.next().await
167 }
168
169 #[cfg(feature = "client")]
171 pub(crate) async fn send(&mut self, message: Message) -> Result<()> {
172 self.0.send(message).await?;
173 Ok(())
174 }
175
176 #[cfg(feature = "client")]
178 pub(crate) async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<()> {
179 self.0.close(msg).await?;
180 Ok(())
181 }
182
183 #[allow(clippy::missing_errors_doc)]
184 pub async fn send_chunk_guild(
185 &mut self,
186 guild_id: GuildId,
187 shard_info: &ShardInfo,
188 limit: Option<u16>,
189 presences: bool,
190 filter: ChunkGuildFilter,
191 nonce: Option<&str>,
192 ) -> Result<()> {
193 debug!("[{:?}] Requesting member chunks", shard_info);
194
195 let (query, user_ids) = match filter {
196 ChunkGuildFilter::None => (Some(String::new()), None),
197 ChunkGuildFilter::Query(query) => (Some(query), None),
198 ChunkGuildFilter::UserIds(user_ids) => (None, Some(user_ids)),
199 };
200
201 self.send_json(&WebSocketMessage {
202 op: Opcode::RequestGuildMembers,
203 d: WebSocketMessageData::ChunkGuild(ChunkGuildMessage {
204 guild_id,
205 query: query.as_deref(),
206 limit: limit.unwrap_or(0),
207 presences,
208 user_ids,
209 nonce: nonce.unwrap_or(""),
210 }),
211 })
212 .await
213 }
214
215 #[instrument(skip(self))]
216 pub async fn send_heartbeat(&mut self, shard_info: &ShardInfo, seq: Option<u64>) -> Result<()> {
217 trace!("[{:?}] Sending heartbeat d: {:?}", shard_info, seq);
218
219 self.send_json(&WebSocketMessage {
220 op: Opcode::Heartbeat,
221 d: WebSocketMessageData::Heartbeat(seq),
222 })
223 .await
224 }
225
226 #[instrument(skip(self, token))]
227 pub async fn send_identify(
228 &mut self,
229 shard: &ShardInfo,
230 token: &str,
231 intents: GatewayIntents,
232 presence: &PresenceData,
233 ) -> Result<()> {
234 let activities: Vec<_> = presence.activity.iter().collect();
235 let now = SystemTime::now();
236
237 debug!("[{:?}] Identifying", shard);
238
239 let msg = WebSocketMessage {
240 op: Opcode::Identify,
241 d: WebSocketMessageData::Identify {
242 token,
243 shard,
244 intents,
245 compress: true,
246 large_threshold: constants::LARGE_THRESHOLD,
247 properties: IdentifyProperties {
248 browser: "serenity",
249 device: "serenity",
250 os: consts::OS,
251 },
252 presence: PresenceUpdateMessage {
253 afk: false,
254 since: now,
255 status: presence.status.name(),
256 activities: &activities,
257 },
258 },
259 };
260
261 self.send_json(&msg).await
262 }
263
264 #[instrument(skip(self))]
265 pub async fn send_presence_update(
266 &mut self,
267 shard_info: &ShardInfo,
268 presence: &PresenceData,
269 ) -> Result<()> {
270 let activities: Vec<_> = presence.activity.iter().collect();
271 let now = SystemTime::now();
272
273 debug!("[{:?}] Sending presence update", shard_info);
274
275 self.send_json(&WebSocketMessage {
276 op: Opcode::PresenceUpdate,
277 d: WebSocketMessageData::PresenceUpdate(PresenceUpdateMessage {
278 afk: false,
279 since: now,
280 status: presence.status.name(),
281 activities: &activities,
282 }),
283 })
284 .await
285 }
286
287 #[instrument(skip(self, token))]
288 pub async fn send_resume(
289 &mut self,
290 shard_info: &ShardInfo,
291 session_id: &str,
292 seq: u64,
293 token: &str,
294 ) -> Result<()> {
295 debug!("[{:?}] Sending resume; seq: {}", shard_info, seq);
296
297 self.send_json(&WebSocketMessage {
298 op: Opcode::Resume,
299 d: WebSocketMessageData::Resume {
300 session_id,
301 token,
302 seq,
303 },
304 })
305 .await
306 }
307}