serenity/http/ratelimiting.rs
1//! Routes are used for ratelimiting. These are to differentiate between the different _types_ of
2//! routes - such as getting the current user's channels - for the most part, with the exception
3//! being major parameters.
4//!
5//! [Taken from] the Discord docs, major parameters are:
6//!
7//! > Additionally, rate limits take into account major parameters in the URL. For example,
8//! > `/channels/:channel_id` and `/channels/:channel_id/messages/:message_id` both take
9//! > `channel_id` into account when generating rate limits since it's the major parameter. The
10//! > only current major parameters are `channel_id`, `guild_id` and `webhook_id`.
11//!
//! This results in the two URLs of `GET /channels/4/messages/7` and `GET /channels/5/messages/8`
//! being rate limited _separately_. However, the two URLs of `GET /channels/10/messages/11` and
//! `GET /channels/10/messages/12` will count towards the same ratelimit, as the major parameter
//! (`10`) is the same in both URLs.
16//!
17//! # Examples
18//!
//! First: take the first two URLs, `GET /channels/4/messages/7` and
//! `GET /channels/5/messages/8`, and assume both buckets have a `limit` of `10`. Requesting the
//! first URL will result in the response containing a `remaining` of `9`. Immediately after,
//! prior to the buckets resetting, performing a request to the _second_ URL will also contain a
//! `remaining` of `9` in the response, as the major parameter, `channel_id`, is different in the
//! two requests (`4` and `5`).
25//!
26//! Second: take for example the last two URLs. Assuming the bucket's `limit` is `10`, requesting
27//! the first URL will return a `remaining` of `9` in the response. Immediately after - prior to
28//! buckets resetting - performing a request to the _second_ URL will return a `remaining` of `8`
29//! in the response, as the major parameter - `channel_id` - is equivalent for the two requests
30//! (`10`).
31//!
//! Major parameters are why some route variants (e.g. the channel and guild variants) carry the
//! ID of their major parameter as data: it differentiates between the separate ratelimits of
//! different channels or guilds.
35//!
36//! [Taken from]: https://discord.com/developers/docs/topics/rate-limits#rate-limits
37
38use std::collections::HashMap;
39use std::fmt;
40use std::str::{self, FromStr};
41use std::sync::Arc;
42use std::time::SystemTime;
43
44use reqwest::header::HeaderMap;
45use reqwest::{Client, Response, StatusCode};
46use secrecy::{ExposeSecret, SecretString};
47use tokio::sync::{Mutex, RwLock};
48use tokio::time::{sleep, Duration};
49use tracing::{debug, instrument};
50
51pub use super::routing::RatelimitingBucket;
52use super::{HttpError, LightMethod, Request};
53use crate::internal::prelude::*;
54
55/// Passed to the [`Ratelimiter::set_ratelimit_callback`] callback. If using Client, that callback
56/// is initialized to call the `EventHandler::ratelimit()` method.
57#[derive(Clone, Debug)]
58#[non_exhaustive]
59pub struct RatelimitInfo {
60 pub timeout: std::time::Duration,
61 pub limit: i64,
62 pub method: LightMethod,
63 pub path: String,
64 pub global: bool,
65}
66
67/// Ratelimiter for requests to the Discord API.
68///
69/// This keeps track of ratelimit data for known routes through the [`Ratelimit`] implementation
70/// for each route: how many tickets are [`remaining`] until the user needs to wait for the known
71/// [`reset`] time, and the [`limit`] of requests that can be made within that time.
72///
73/// When no tickets are available for some time, then the thread sleeps until that time passes. The
74/// mechanism is known as "pre-emptive ratelimiting".
75///
76/// Occasionally for very high traffic bots, a global ratelimit may be reached which blocks all
77/// future requests until the global ratelimit is over, regardless of route. The value of this
78/// global ratelimit is never given through the API, so it can't be pre-emptively ratelimited. This
79/// only affects the largest of bots.
80///
81/// [`limit`]: Ratelimit::limit
82/// [`remaining`]: Ratelimit::remaining
83/// [`reset`]: Ratelimit::reset
84pub struct Ratelimiter {
85 client: Client,
86 global: Arc<Mutex<()>>,
87 // When futures is implemented, make tasks clear out their respective entry when the 'reset'
88 // passes.
89 routes: Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>>,
90 token: SecretString,
91 absolute_ratelimits: bool,
92 ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
93}
94
95impl fmt::Debug for Ratelimiter {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.debug_struct("Ratelimiter")
98 .field("client", &self.client)
99 .field("global", &self.global)
100 .field("routes", &self.routes)
101 .field("token", &self.token)
102 .field("absolute_ratelimits", &self.absolute_ratelimits)
103 .field("ratelimit_callback", &"Fn(RatelimitInfo)")
104 .finish()
105 }
106}
107
108impl Ratelimiter {
109 /// Creates a new ratelimiter, with a shared [`reqwest`] client and the bot's token.
110 ///
111 /// The bot token must be prefixed with `"Bot "`. The ratelimiter does not prefix it.
112 #[must_use]
113 pub fn new(client: Client, token: impl Into<String>) -> Self {
114 Self::new_(client, token.into())
115 }
116
117 fn new_(client: Client, token: String) -> Self {
118 Self {
119 client,
120 global: Arc::default(),
121 routes: Arc::default(),
122 token: SecretString::new(token),
123 ratelimit_callback: Box::new(|_| {}),
124 absolute_ratelimits: false,
125 }
126 }
127
128 /// Sets a callback to be called when a route is rate limited.
129 pub fn set_ratelimit_callback(
130 &mut self,
131 ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
132 ) {
133 self.ratelimit_callback = ratelimit_callback;
134 }
135
136 // Sets whether absolute ratelimits should be used.
137 pub fn set_absolute_ratelimits(&mut self, absolute_ratelimits: bool) {
138 self.absolute_ratelimits = absolute_ratelimits;
139 }
140
141 /// The routes mutex is a HashMap of each [`RatelimitingBucket`] and their respective ratelimit
142 /// information.
143 ///
144 /// See the documentation for [`Ratelimit`] for more information on how the library handles
145 /// ratelimiting.
146 ///
147 /// # Examples
148 ///
149 /// View the `reset` time of the route for `ChannelsId(7)`:
150 ///
151 /// ```rust,no_run
152 /// use serenity::http::Route;
153 /// # use serenity::http::Http;
154 /// # use serenity::model::prelude::*;
155 ///
156 /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
157 /// # let http: Http = unimplemented!();
158 /// let routes = http.ratelimiter.unwrap().routes();
159 /// let reader = routes.read().await;
160 ///
161 /// let channel_id = ChannelId::new(7);
162 /// let route = Route::Channel {
163 /// channel_id,
164 /// };
165 /// if let Some(route) = reader.get(&route.ratelimiting_bucket()) {
166 /// if let Some(reset) = route.lock().await.reset() {
167 /// println!("Reset time at: {:?}", reset);
168 /// }
169 /// }
170 /// # Ok(())
171 /// # }
172 /// ```
173 #[must_use]
174 pub fn routes(&self) -> Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>> {
175 Arc::clone(&self.routes)
176 }
177
178 /// # Errors
179 ///
180 /// Only error kind that may be returned is [`Error::Http`].
181 #[instrument]
182 pub async fn perform(&self, req: Request<'_>) -> Result<Response> {
183 loop {
184 // This will block if another thread hit the global ratelimit.
185 drop(self.global.lock().await);
186
187 // Perform pre-checking here:
188 // - get the route's relevant rate
189 // - sleep if that route's already rate-limited until the end of the 'reset' time;
190 // - get the global rate;
191 // - sleep if there is 0 remaining
192 // - then, perform the request
193 let ratelimiting_bucket = req.route.ratelimiting_bucket();
194 let bucket =
195 Arc::clone(self.routes.write().await.entry(ratelimiting_bucket).or_default());
196
197 bucket.lock().await.pre_hook(&req, &self.ratelimit_callback).await;
198
199 let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
200 let response = self.client.execute(request.build()?).await?;
201
202 // Check if the request got ratelimited by checking for status 429, and if so, sleep
203 // for the value of the header 'retry-after' - which is in milliseconds - and then
204 // `continue` to try again
205 //
206 // If it didn't ratelimit, subtract one from the Ratelimit's 'remaining'.
207 //
208 // Update `reset` with the value of 'x-ratelimit-reset' header. Similarly, update
209 // `reset-after` with the 'x-ratelimit-reset-after' header.
210 //
211 // It _may_ be possible for the limit to be raised at any time, so check if it did from
212 // the value of the 'x-ratelimit-limit' header. If the limit was 5 and is now 7, add 2
213 // to the 'remaining'
214 if ratelimiting_bucket.is_none() {
215 return Ok(response);
216 }
217
218 let redo = if response.headers().get("x-ratelimit-global").is_some() {
219 drop(self.global.lock().await);
220
221 Ok(
222 if let Some(retry_after) =
223 parse_header::<f64>(response.headers(), "retry-after")?
224 {
225 debug!(
226 "Ratelimited on route {:?} for {:?}s",
227 ratelimiting_bucket, retry_after
228 );
229 (self.ratelimit_callback)(RatelimitInfo {
230 timeout: Duration::from_secs_f64(retry_after),
231 limit: 50,
232 method: req.method,
233 path: req.route.path().to_string(),
234 global: true,
235 });
236 sleep(Duration::from_secs_f64(retry_after)).await;
237
238 true
239 } else {
240 false
241 },
242 )
243 } else {
244 bucket
245 .lock()
246 .await
247 .post_hook(&response, &req, &self.ratelimit_callback, self.absolute_ratelimits)
248 .await
249 };
250
251 if !redo.unwrap_or(true) {
252 return Ok(response);
253 }
254 }
255 }
256}
257
258/// A set of data containing information about the ratelimits for a particular
259/// [`RatelimitingBucket`], which is stored in [`Http`].
260///
261/// See the [Discord docs] on ratelimits for more information.
262///
263/// **Note**: You should _not_ mutate any of the fields, as this can help cause 429s.
264///
265/// [`Http`]: super::Http
266/// [Discord docs]: https://discord.com/developers/docs/topics/rate-limits
267#[derive(Debug)]
268pub struct Ratelimit {
269 /// The total number of requests that can be made in a period of time.
270 limit: i64,
271 /// The number of requests remaining in the period of time.
272 remaining: i64,
273 /// The absolute time when the interval resets.
274 reset: Option<SystemTime>,
275 /// The total time when the interval resets.
276 reset_after: Option<Duration>,
277}
278
279impl Ratelimit {
280 #[instrument(skip(ratelimit_callback))]
281 pub async fn pre_hook(
282 &mut self,
283 req: &Request<'_>,
284 ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
285 ) {
286 if self.limit() == 0 {
287 return;
288 }
289
290 let Some(reset) = self.reset else {
291 // We're probably in the past.
292 self.remaining = self.limit;
293 return;
294 };
295
296 let Ok(delay) = reset.duration_since(SystemTime::now()) else {
297 // if duration is negative (i.e. adequate time has passed since last call to this api)
298 if self.remaining() != 0 {
299 self.remaining -= 1;
300 }
301 return;
302 };
303
304 if self.remaining() == 0 {
305 debug!(
306 "Pre-emptive ratelimit on route {:?} for {}ms",
307 req.route.ratelimiting_bucket(),
308 delay.as_millis(),
309 );
310 ratelimit_callback(RatelimitInfo {
311 timeout: delay,
312 limit: self.limit,
313 method: req.method,
314 path: req.route.path().to_string(),
315 global: false,
316 });
317
318 sleep(delay).await;
319
320 return;
321 }
322
323 self.remaining -= 1;
324 }
325
326 /// # Errors
327 ///
328 /// Errors if parsing headers from the response fails.
329 #[instrument(skip(ratelimit_callback))]
330 pub async fn post_hook(
331 &mut self,
332 response: &Response,
333 req: &Request<'_>,
334 ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
335 absolute_ratelimits: bool,
336 ) -> Result<bool> {
337 if let Some(limit) = parse_header(response.headers(), "x-ratelimit-limit")? {
338 self.limit = limit;
339 }
340
341 if let Some(remaining) = parse_header(response.headers(), "x-ratelimit-remaining")? {
342 self.remaining = remaining;
343 }
344
345 if absolute_ratelimits {
346 if let Some(reset) = parse_header::<f64>(response.headers(), "x-ratelimit-reset")? {
347 self.reset = Some(std::time::UNIX_EPOCH + Duration::from_secs_f64(reset));
348 }
349 }
350
351 if let Some(reset_after) =
352 parse_header::<f64>(response.headers(), "x-ratelimit-reset-after")?
353 {
354 if !absolute_ratelimits {
355 self.reset = Some(SystemTime::now() + Duration::from_secs_f64(reset_after));
356 }
357
358 self.reset_after = Some(Duration::from_secs_f64(reset_after));
359 }
360
361 Ok(if response.status() != StatusCode::TOO_MANY_REQUESTS {
362 false
363 } else if let Some(retry_after) = parse_header::<f64>(response.headers(), "retry-after")? {
364 debug!(
365 "Ratelimited on route {:?} for {:?}s",
366 req.route.ratelimiting_bucket(),
367 retry_after
368 );
369 ratelimit_callback(RatelimitInfo {
370 timeout: Duration::from_secs_f64(retry_after),
371 limit: self.limit,
372 method: req.method,
373 path: req.route.path().to_string(),
374 global: false,
375 });
376
377 sleep(Duration::from_secs_f64(retry_after)).await;
378
379 true
380 } else {
381 false
382 })
383 }
384
385 /// The total number of requests that can be made in a period of time.
386 #[inline]
387 #[must_use]
388 pub const fn limit(&self) -> i64 {
389 self.limit
390 }
391
392 /// The number of requests remaining in the period of time.
393 #[inline]
394 #[must_use]
395 pub const fn remaining(&self) -> i64 {
396 self.remaining
397 }
398
399 /// The absolute time in milliseconds when the interval resets.
400 #[inline]
401 #[must_use]
402 pub const fn reset(&self) -> Option<SystemTime> {
403 self.reset
404 }
405
406 /// The total time in milliseconds when the interval resets.
407 #[inline]
408 #[must_use]
409 pub const fn reset_after(&self) -> Option<Duration> {
410 self.reset_after
411 }
412}
413
414impl Default for Ratelimit {
415 fn default() -> Self {
416 Self {
417 limit: i64::MAX,
418 remaining: i64::MAX,
419 reset: None,
420 reset_after: None,
421 }
422 }
423}
424
425fn parse_header<T: FromStr>(headers: &HeaderMap, header: &str) -> Result<Option<T>> {
426 let Some(header) = headers.get(header) else { return Ok(None) };
427
428 let unicode =
429 str::from_utf8(header.as_bytes()).map_err(|_| Error::from(HttpError::RateLimitUtf8))?;
430
431 let num = unicode.parse().map_err(|_| Error::from(HttpError::RateLimitI64F64))?;
432
433 Ok(Some(num))
434}
435
436#[cfg(test)]
437mod tests {
438 use std::error::Error as StdError;
439 use std::result::Result as StdResult;
440
441 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
442
443 use super::parse_header;
444 use crate::error::Error;
445 use crate::http::HttpError;
446
447 type Result<T> = StdResult<T, Box<dyn StdError>>;
448
449 fn headers() -> HeaderMap {
450 let pairs = &[
451 (HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_static("5")),
452 (HeaderName::from_static("x-ratelimit-remaining"), HeaderValue::from_static("4")),
453 (
454 HeaderName::from_static("x-ratelimit-reset"),
455 HeaderValue::from_static("1560704880.423"),
456 ),
457 (HeaderName::from_static("x-bad-num"), HeaderValue::from_static("abc")),
458 (
459 HeaderName::from_static("x-bad-unicode"),
460 HeaderValue::from_bytes(&[255, 255, 255, 255]).unwrap(),
461 ),
462 ];
463
464 let mut map = HeaderMap::with_capacity(pairs.len());
465
466 for (name, val) in pairs {
467 map.insert(name, val.clone());
468 }
469
470 map
471 }
472
473 #[test]
474 #[allow(clippy::float_cmp)]
475 fn test_parse_header_good() -> Result<()> {
476 let headers = headers();
477
478 assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-limit")?.unwrap(), 5);
479 assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-remaining")?.unwrap(), 4,);
480 assert_eq!(parse_header::<f64>(&headers, "x-ratelimit-reset")?.unwrap(), 1_560_704_880.423);
481
482 Ok(())
483 }
484
485 #[test]
486 fn test_parse_header_errors() {
487 let headers = headers();
488
489 assert!(matches!(
490 parse_header::<i64>(&headers, "x-bad-num").unwrap_err(),
491 Error::Http(HttpError::RateLimitI64F64)
492 ));
493 assert!(matches!(
494 parse_header::<i64>(&headers, "x-bad-unicode").unwrap_err(),
495 Error::Http(HttpError::RateLimitUtf8)
496 ));
497 }
498}