serenity/http/ratelimiting.rs
//! Routes are used for ratelimiting. For the most part, they differentiate between the
//! different _types_ of routes (such as getting the current user's channels), with the
//! exception being major parameters.
4//!
5//! [Taken from] the Discord docs, major parameters are:
6//!
7//! > Additionally, rate limits take into account major parameters in the URL. For example,
8//! > `/channels/:channel_id` and `/channels/:channel_id/messages/:message_id` both take
9//! > `channel_id` into account when generating rate limits since it's the major parameter. The
10//! > only current major parameters are `channel_id`, `guild_id` and `webhook_id`.
11//!
//! This results in the two URLs of `GET /channels/4/messages/7` and `GET /channels/5/messages/8`
//! being rate limited _separately_. However, the two URLs of `GET /channels/10/messages/11` and
//! `GET /channels/10/messages/12` will count towards the same ratelimit, as the major parameter
//! (`channel_id`, here `10`) is identical in both URLs.
16//!
17//! # Examples
18//!
//! First, take the first two URLs, `GET /channels/4/messages/7` and
//! `GET /channels/5/messages/8`, and assume both buckets have a `limit` of `10`. Requesting the
//! first URL will result in the response containing a `remaining` of `9`. Immediately after,
//! prior to buckets resetting, performing a request to the _second_ URL will also contain a
//! `remaining` of `9` in the response, as the major parameter (`channel_id`) is different in the
//! two requests (`4` and `5`).
25//!
26//! Second: take for example the last two URLs. Assuming the bucket's `limit` is `10`, requesting
27//! the first URL will return a `remaining` of `9` in the response. Immediately after - prior to
28//! buckets resetting - performing a request to the _second_ URL will return a `remaining` of `8`
29//! in the response, as the major parameter - `channel_id` - is equivalent for the two requests
30//! (`10`).
31//!
//! Major parameters are why some variants (e.g. the channel and guild variants) carry associated
//! data: the ID of the major parameter, which differentiates between the ratelimits of different
//! resources.
35//!
36//! [Taken from]: https://discord.com/developers/docs/topics/rate-limits#rate-limits
37
38use std::collections::HashMap;
39use std::fmt;
40use std::str::{self, FromStr};
41use std::sync::Arc;
42use std::time::SystemTime;
43
44use reqwest::header::HeaderMap;
45use reqwest::{Client, Response, StatusCode};
46use secrecy::{ExposeSecret, SecretString};
47use tokio::sync::{Mutex, RwLock};
48use tokio::time::{sleep, Duration};
49use tracing::{debug, instrument};
50
51pub use super::routing::RatelimitingBucket;
52use super::{HttpError, LightMethod, Request};
53use crate::internal::prelude::*;
54
55/// Passed to the [`Ratelimiter::set_ratelimit_callback`] callback. If using Client, that callback
56/// is initialized to call the `EventHandler::ratelimit()` method.
57#[derive(Clone, Debug)]
58#[non_exhaustive]
59pub struct RatelimitInfo {
60 pub timeout: std::time::Duration,
61 pub limit: i64,
62 pub method: LightMethod,
63 pub path: String,
64 pub global: bool,
65}
66
67/// Ratelimiter for requests to the Discord API.
68///
69/// This keeps track of ratelimit data for known routes through the [`Ratelimit`] implementation
70/// for each route: how many tickets are [`remaining`] until the user needs to wait for the known
71/// [`reset`] time, and the [`limit`] of requests that can be made within that time.
72///
73/// When no tickets are available for some time, then the thread sleeps until that time passes. The
74/// mechanism is known as "pre-emptive ratelimiting".
75///
76/// Occasionally for very high traffic bots, a global ratelimit may be reached which blocks all
77/// future requests until the global ratelimit is over, regardless of route. The value of this
78/// global ratelimit is never given through the API, so it can't be pre-emptively ratelimited. This
79/// only affects the largest of bots.
80///
81/// [`limit`]: Ratelimit::limit
82/// [`remaining`]: Ratelimit::remaining
83/// [`reset`]: Ratelimit::reset
84pub struct Ratelimiter {
85 client: Client,
86 global: Arc<Mutex<()>>,
87 // When futures is implemented, make tasks clear out their respective entry when the 'reset'
88 // passes.
89 routes: Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>>,
90 token: SecretString,
91 absolute_ratelimits: bool,
92 ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
93}
94
95impl fmt::Debug for Ratelimiter {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.debug_struct("Ratelimiter")
98 .field("client", &self.client)
99 .field("global", &self.global)
100 .field("routes", &self.routes)
101 .field("token", &self.token)
102 .field("absolute_ratelimits", &self.absolute_ratelimits)
103 .field("ratelimit_callback", &"Fn(RatelimitInfo)")
104 .finish()
105 }
106}
107
108impl Ratelimiter {
109 /// Creates a new ratelimiter, with a shared [`reqwest`] client and the bot's token.
110 ///
111 /// The bot token must be prefixed with `"Bot "`. The ratelimiter does not prefix it.
112 #[must_use]
113 pub fn new(client: Client, token: impl Into<String>) -> Self {
114 Self::new_(client, token.into())
115 }
116
117 fn new_(client: Client, token: String) -> Self {
118 Self {
119 client,
120 global: Arc::default(),
121 routes: Arc::default(),
122 token: SecretString::new(token),
123 ratelimit_callback: Box::new(|_| {}),
124 absolute_ratelimits: false,
125 }
126 }
127
128 /// Sets a callback to be called when a route is rate limited.
129 pub fn set_ratelimit_callback(
130 &mut self,
131 ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
132 ) {
133 self.ratelimit_callback = ratelimit_callback;
134 }
135
136 // Sets whether absolute ratelimits should be used.
137 pub fn set_absolute_ratelimits(&mut self, absolute_ratelimits: bool) {
138 self.absolute_ratelimits = absolute_ratelimits;
139 }
140
141 /// The routes mutex is a HashMap of each [`RatelimitingBucket`] and their respective ratelimit
142 /// information.
143 ///
144 /// See the documentation for [`Ratelimit`] for more information on how the library handles
145 /// ratelimiting.
146 ///
147 /// # Examples
148 ///
149 /// View the `reset` time of the route for `ChannelsId(7)`:
150 ///
151 /// ```rust,no_run
152 /// use serenity::http::Route;
153 /// # use serenity::http::Http;
154 /// # use serenity::model::prelude::*;
155 ///
156 /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
157 /// # let http: Http = unimplemented!();
158 /// let routes = http.ratelimiter.unwrap().routes();
159 /// let reader = routes.read().await;
160 ///
161 /// let channel_id = ChannelId::new(7);
162 /// let route = Route::Channel {
163 /// channel_id,
164 /// };
165 /// if let Some(route) = reader.get(&route.ratelimiting_bucket()) {
166 /// if let Some(reset) = route.lock().await.reset() {
167 /// println!("Reset time at: {:?}", reset);
168 /// }
169 /// }
170 /// # Ok(())
171 /// # }
172 /// ```
173 #[must_use]
174 pub fn routes(&self) -> Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>> {
175 Arc::clone(&self.routes)
176 }
177
178 /// # Errors
179 ///
180 /// Only error kind that may be returned is [`Error::Http`].
181 #[instrument]
182 pub async fn perform(&self, req: Request<'_>) -> Result<Response> {
183 loop {
184 // This will block if another thread hit the global ratelimit.
185 drop(self.global.lock().await);
186
187 // Perform pre-checking here:
188 // - get the route's relevant rate
189 // - sleep if that route's already rate-limited until the end of the 'reset' time;
190 // - get the global rate;
191 // - sleep if there is 0 remaining
192 // - then, perform the request
193 let ratelimiting_bucket = req.route.ratelimiting_bucket();
194 let bucket =
195 Arc::clone(self.routes.write().await.entry(ratelimiting_bucket).or_default());
196
197 bucket.lock().await.pre_hook(&req, &self.ratelimit_callback).await;
198
199 let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
200 let response = self.client.execute(request.build()?).await?;
201
202 // Check if the request got ratelimited by checking for status 429, and if so, sleep
203 // for the value of the header 'retry-after' - which is in milliseconds - and then
204 // `continue` to try again
205 //
206 // If it didn't ratelimit, subtract one from the Ratelimit's 'remaining'.
207 //
208 // Update `reset` with the value of 'x-ratelimit-reset' header. Similarly, update
209 // `reset-after` with the 'x-ratelimit-reset-after' header.
210 //
211 // It _may_ be possible for the limit to be raised at any time, so check if it did from
212 // the value of the 'x-ratelimit-limit' header. If the limit was 5 and is now 7, add 2
213 // to the 'remaining'
214 if ratelimiting_bucket.is_none() {
215 return Ok(response);
216 }
217
218 let redo = if response.headers().get("x-ratelimit-global").is_some() {
219 drop(self.global.lock().await);
220
221 Ok(
222 if let Some(retry_after) =
223 parse_header::<f64>(response.headers(), "retry-after")?
224 {
225 debug!(
226 "Ratelimited on route {:?} for {:?}s",
227 ratelimiting_bucket, retry_after
228 );
229 (self.ratelimit_callback)(RatelimitInfo {
230 timeout: Duration::from_secs_f64(retry_after),
231 limit: 50,
232 method: req.method,
233 path: req.route.path().to_string(),
234 global: true,
235 });
236 sleep(Duration::from_secs_f64(retry_after)).await;
237
238 true
239 } else {
240 false
241 },
242 )
243 } else {
244 bucket
245 .lock()
246 .await
247 .post_hook(&response, &req, &self.ratelimit_callback, self.absolute_ratelimits)
248 .await
249 };
250
251 if !redo.unwrap_or(true) {
252 return Ok(response);
253 }
254 }
255 }
256}
257
/// A set of data containing information about the ratelimits for a particular
/// [`RatelimitingBucket`], which is stored in [`Http`]'s ratelimiter.
///
/// See the [Discord docs] on ratelimits for more information.
///
/// The fields are private: they are updated from the `x-ratelimit-*` response headers and can
/// only be read through the getter methods.
///
/// [`Http`]: super::Http
/// [Discord docs]: https://discord.com/developers/docs/topics/rate-limits
#[derive(Debug)]
pub struct Ratelimit {
    /// The total number of requests that can be made in a period of time.
    limit: i64,
    /// The number of requests remaining in the period of time.
    remaining: i64,
    /// The absolute time when the interval resets, if known.
    reset: Option<SystemTime>,
    /// The duration after which the interval resets, as reported by the last response, if any.
    reset_after: Option<Duration>,
}
278
279impl Ratelimit {
280 #[instrument(skip(ratelimit_callback))]
281 pub async fn pre_hook(
282 &mut self,
283 req: &Request<'_>,
284 ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
285 ) {
286 if self.limit() == 0 {
287 return;
288 }
289
290 let Some(reset) = self.reset else {
291 // We're probably in the past.
292 self.remaining = self.limit;
293 return;
294 };
295
296 let Ok(delay) = reset.duration_since(SystemTime::now()) else {
297 // if duration is negative (i.e. adequate time has passed since last call to this api)
298 if self.remaining() != 0 {
299 self.remaining -= 1;
300 }
301 return;
302 };
303
304 if self.remaining() == 0 {
305 debug!(
306 "Pre-emptive ratelimit on route {:?} for {}ms",
307 req.route.ratelimiting_bucket(),
308 delay.as_millis(),
309 );
310 ratelimit_callback(RatelimitInfo {
311 timeout: delay,
312 limit: self.limit,
313 method: req.method,
314 path: req.route.path().to_string(),
315 global: false,
316 });
317
318 sleep(delay).await;
319
320 return;
321 }
322
323 self.remaining -= 1;
324 }
325
326 #[instrument(skip(ratelimit_callback))]
327 pub async fn post_hook(
328 &mut self,
329 response: &Response,
330 req: &Request<'_>,
331 ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
332 absolute_ratelimits: bool,
333 ) -> Result<bool> {
334 if let Some(limit) = parse_header(response.headers(), "x-ratelimit-limit")? {
335 self.limit = limit;
336 }
337
338 if let Some(remaining) = parse_header(response.headers(), "x-ratelimit-remaining")? {
339 self.remaining = remaining;
340 }
341
342 if absolute_ratelimits {
343 if let Some(reset) = parse_header::<f64>(response.headers(), "x-ratelimit-reset")? {
344 self.reset = Some(std::time::UNIX_EPOCH + Duration::from_secs_f64(reset));
345 }
346 }
347
348 if let Some(reset_after) =
349 parse_header::<f64>(response.headers(), "x-ratelimit-reset-after")?
350 {
351 if !absolute_ratelimits {
352 self.reset = Some(SystemTime::now() + Duration::from_secs_f64(reset_after));
353 }
354
355 self.reset_after = Some(Duration::from_secs_f64(reset_after));
356 }
357
358 Ok(if response.status() != StatusCode::TOO_MANY_REQUESTS {
359 false
360 } else if let Some(retry_after) = parse_header::<f64>(response.headers(), "retry-after")? {
361 debug!(
362 "Ratelimited on route {:?} for {:?}s",
363 req.route.ratelimiting_bucket(),
364 retry_after
365 );
366 ratelimit_callback(RatelimitInfo {
367 timeout: Duration::from_secs_f64(retry_after),
368 limit: self.limit,
369 method: req.method,
370 path: req.route.path().to_string(),
371 global: false,
372 });
373
374 sleep(Duration::from_secs_f64(retry_after)).await;
375
376 true
377 } else {
378 false
379 })
380 }
381
382 /// The total number of requests that can be made in a period of time.
383 #[inline]
384 #[must_use]
385 pub const fn limit(&self) -> i64 {
386 self.limit
387 }
388
389 /// The number of requests remaining in the period of time.
390 #[inline]
391 #[must_use]
392 pub const fn remaining(&self) -> i64 {
393 self.remaining
394 }
395
396 /// The absolute time in milliseconds when the interval resets.
397 #[inline]
398 #[must_use]
399 pub const fn reset(&self) -> Option<SystemTime> {
400 self.reset
401 }
402
403 /// The total time in milliseconds when the interval resets.
404 #[inline]
405 #[must_use]
406 pub const fn reset_after(&self) -> Option<Duration> {
407 self.reset_after
408 }
409}
410
411impl Default for Ratelimit {
412 fn default() -> Self {
413 Self {
414 limit: i64::MAX,
415 remaining: i64::MAX,
416 reset: None,
417 reset_after: None,
418 }
419 }
420}
421
422fn parse_header<T: FromStr>(headers: &HeaderMap, header: &str) -> Result<Option<T>> {
423 let Some(header) = headers.get(header) else { return Ok(None) };
424
425 let unicode =
426 str::from_utf8(header.as_bytes()).map_err(|_| Error::from(HttpError::RateLimitUtf8))?;
427
428 let num = unicode.parse().map_err(|_| Error::from(HttpError::RateLimitI64F64))?;
429
430 Ok(Some(num))
431}
432
#[cfg(test)]
mod tests {
    use std::error::Error as StdError;
    use std::result::Result as StdResult;

    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

    use super::parse_header;
    use crate::error::Error;
    use crate::http::HttpError;

    type Result<T> = StdResult<T, Box<dyn StdError>>;

    /// Builds a header map with well-formed ratelimit headers, plus one non-numeric and one
    /// non-UTF-8 header for the error-path tests.
    fn headers() -> HeaderMap {
        let pairs = [
            (HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_static("5")),
            (HeaderName::from_static("x-ratelimit-remaining"), HeaderValue::from_static("4")),
            (
                HeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from_static("1560704880.423"),
            ),
            (HeaderName::from_static("x-bad-num"), HeaderValue::from_static("abc")),
            (
                HeaderName::from_static("x-bad-unicode"),
                HeaderValue::from_bytes(&[255, 255, 255, 255]).unwrap(),
            ),
        ];

        let mut map = HeaderMap::with_capacity(pairs.len());
        for (name, value) in pairs {
            map.insert(name, value);
        }

        map
    }

    #[test]
    // The expected float is the exact value parsed from the same literal text, so an exact
    // comparison is intended here.
    #[allow(clippy::float_cmp)]
    fn test_parse_header_good() -> Result<()> {
        let headers = headers();

        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-limit")?.unwrap(), 5);
        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-remaining")?.unwrap(), 4);
        assert_eq!(parse_header::<f64>(&headers, "x-ratelimit-reset")?.unwrap(), 1_560_704_880.423);

        Ok(())
    }

    #[test]
    fn test_parse_header_missing() -> Result<()> {
        let headers = headers();

        assert!(parse_header::<i64>(&headers, "x-ratelimit-reset-after")?.is_none());
        assert!(parse_header::<f64>(&headers, "retry-after")?.is_none());

        Ok(())
    }

    #[test]
    fn test_parse_header_errors() {
        let headers = headers();

        assert!(matches!(
            parse_header::<i64>(&headers, "x-bad-num").unwrap_err(),
            Error::Http(HttpError::RateLimitI64F64)
        ));
        assert!(matches!(
            parse_header::<i64>(&headers, "x-bad-unicode").unwrap_err(),
            Error::Http(HttpError::RateLimitUtf8)
        ));
    }
}