// serenity/http/ratelimiting.rs

1//! Routes are used for ratelimiting. These are to differentiate between the different _types_ of
2//! routes - such as getting the current user's channels - for the most part, with the exception
3//! being major parameters.
4//!
5//! [Taken from] the Discord docs, major parameters are:
6//!
7//! > Additionally, rate limits take into account major parameters in the URL. For example,
8//! > `/channels/:channel_id` and `/channels/:channel_id/messages/:message_id` both take
9//! > `channel_id` into account when generating rate limits since it's the major parameter. The
10//! > only current major parameters are `channel_id`, `guild_id` and `webhook_id`.
11//!
12//! This results in the two URLs of `GET /channels/4/messages/7` and `GET /channels/5/messages/8`
13//! being rate limited _separately_. However, the two URLs of `GET /channels/10/messages/11` and
//! `GET /channels/10/messages/12` will count towards the "same ratelimit", as the major parameter
//! (`10`) is equivalent in both URLs.
16//!
17//! # Examples
18//!
19//! First: taking the first two URLs - `GET /channels/4/messages/7` and `GET
20//! /channels/5/messages/8` - and assuming both buckets have a `limit` of `10`, requesting the
21//! first URL will result in the response containing a `remaining` of `9`. Immediately after -
22//! prior to buckets resetting - performing a request to the _second_ URL will also contain a
23//! `remaining` of `9` in the response, as the major parameter - `channel_id` - is different in the
24//! two requests (`4` and `5`).
25//!
26//! Second: take for example the last two URLs. Assuming the bucket's `limit` is `10`, requesting
27//! the first URL will return a `remaining` of `9` in the response. Immediately after - prior to
28//! buckets resetting - performing a request to the _second_ URL will return a `remaining` of `8`
29//! in the response, as the major parameter - `channel_id` - is equivalent for the two requests
30//! (`10`).
31//!
32//! Major parameters are why some variants (i.e. all of the channel/guild variants) have an
33//! associated u64 as data. This is the Id of the parameter, differentiating between different
34//! ratelimits.
35//!
36//! [Taken from]: https://discord.com/developers/docs/topics/rate-limits#rate-limits
37
38use std::collections::HashMap;
39use std::fmt;
40use std::str::{self, FromStr};
41use std::sync::Arc;
42use std::time::SystemTime;
43
44use reqwest::header::HeaderMap;
45use reqwest::{Client, Response, StatusCode};
46use secrecy::{ExposeSecret, SecretString};
47use tokio::sync::{Mutex, RwLock};
48use tokio::time::{sleep, Duration};
49use tracing::{debug, instrument};
50
51pub use super::routing::RatelimitingBucket;
52use super::{HttpError, LightMethod, Request};
53use crate::internal::prelude::*;
54
55/// Passed to the [`Ratelimiter::set_ratelimit_callback`] callback. If using Client, that callback
56/// is initialized to call the `EventHandler::ratelimit()` method.
57#[derive(Clone, Debug)]
58#[non_exhaustive]
59pub struct RatelimitInfo {
60    pub timeout: std::time::Duration,
61    pub limit: i64,
62    pub method: LightMethod,
63    pub path: String,
64    pub global: bool,
65}
66
67/// Ratelimiter for requests to the Discord API.
68///
69/// This keeps track of ratelimit data for known routes through the [`Ratelimit`] implementation
70/// for each route: how many tickets are [`remaining`] until the user needs to wait for the known
71/// [`reset`] time, and the [`limit`] of requests that can be made within that time.
72///
73/// When no tickets are available for some time, then the thread sleeps until that time passes. The
74/// mechanism is known as "pre-emptive ratelimiting".
75///
76/// Occasionally for very high traffic bots, a global ratelimit may be reached which blocks all
77/// future requests until the global ratelimit is over, regardless of route. The value of this
78/// global ratelimit is never given through the API, so it can't be pre-emptively ratelimited. This
79/// only affects the largest of bots.
80///
81/// [`limit`]: Ratelimit::limit
82/// [`remaining`]: Ratelimit::remaining
83/// [`reset`]: Ratelimit::reset
84pub struct Ratelimiter {
85    client: Client,
86    global: Arc<Mutex<()>>,
87    // When futures is implemented, make tasks clear out their respective entry when the 'reset'
88    // passes.
89    routes: Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>>,
90    token: SecretString,
91    absolute_ratelimits: bool,
92    ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
93}
94
95impl fmt::Debug for Ratelimiter {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        f.debug_struct("Ratelimiter")
98            .field("client", &self.client)
99            .field("global", &self.global)
100            .field("routes", &self.routes)
101            .field("token", &self.token)
102            .field("absolute_ratelimits", &self.absolute_ratelimits)
103            .field("ratelimit_callback", &"Fn(RatelimitInfo)")
104            .finish()
105    }
106}
107
108impl Ratelimiter {
109    /// Creates a new ratelimiter, with a shared [`reqwest`] client and the bot's token.
110    ///
111    /// The bot token must be prefixed with `"Bot "`. The ratelimiter does not prefix it.
112    #[must_use]
113    pub fn new(client: Client, token: impl Into<String>) -> Self {
114        Self::new_(client, token.into())
115    }
116
117    fn new_(client: Client, token: String) -> Self {
118        Self {
119            client,
120            global: Arc::default(),
121            routes: Arc::default(),
122            token: SecretString::new(token),
123            ratelimit_callback: Box::new(|_| {}),
124            absolute_ratelimits: false,
125        }
126    }
127
128    /// Sets a callback to be called when a route is rate limited.
129    pub fn set_ratelimit_callback(
130        &mut self,
131        ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
132    ) {
133        self.ratelimit_callback = ratelimit_callback;
134    }
135
136    // Sets whether absolute ratelimits should be used.
137    pub fn set_absolute_ratelimits(&mut self, absolute_ratelimits: bool) {
138        self.absolute_ratelimits = absolute_ratelimits;
139    }
140
141    /// The routes mutex is a HashMap of each [`RatelimitingBucket`] and their respective ratelimit
142    /// information.
143    ///
144    /// See the documentation for [`Ratelimit`] for more information on how the library handles
145    /// ratelimiting.
146    ///
147    /// # Examples
148    ///
149    /// View the `reset` time of the route for `ChannelsId(7)`:
150    ///
151    /// ```rust,no_run
152    /// use serenity::http::Route;
153    /// # use serenity::http::Http;
154    /// # use serenity::model::prelude::*;
155    ///
156    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
157    /// # let http: Http = unimplemented!();
158    /// let routes = http.ratelimiter.unwrap().routes();
159    /// let reader = routes.read().await;
160    ///
161    /// let channel_id = ChannelId::new(7);
162    /// let route = Route::Channel {
163    ///     channel_id,
164    /// };
165    /// if let Some(route) = reader.get(&route.ratelimiting_bucket()) {
166    ///     if let Some(reset) = route.lock().await.reset() {
167    ///         println!("Reset time at: {:?}", reset);
168    ///     }
169    /// }
170    /// # Ok(())
171    /// # }
172    /// ```
173    #[must_use]
174    pub fn routes(&self) -> Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>> {
175        Arc::clone(&self.routes)
176    }
177
178    /// # Errors
179    ///
180    /// Only error kind that may be returned is [`Error::Http`].
181    #[instrument]
182    pub async fn perform(&self, req: Request<'_>) -> Result<Response> {
183        loop {
184            // This will block if another thread hit the global ratelimit.
185            drop(self.global.lock().await);
186
187            // Perform pre-checking here:
188            // - get the route's relevant rate
189            // - sleep if that route's already rate-limited until the end of the 'reset' time;
190            // - get the global rate;
191            // - sleep if there is 0 remaining
192            // - then, perform the request
193            let ratelimiting_bucket = req.route.ratelimiting_bucket();
194            let bucket =
195                Arc::clone(self.routes.write().await.entry(ratelimiting_bucket).or_default());
196
197            bucket.lock().await.pre_hook(&req, &self.ratelimit_callback).await;
198
199            let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
200            let response = self.client.execute(request.build()?).await?;
201
202            // Check if the request got ratelimited by checking for status 429, and if so, sleep
203            // for the value of the header 'retry-after' - which is in milliseconds - and then
204            // `continue` to try again
205            //
206            // If it didn't ratelimit, subtract one from the Ratelimit's 'remaining'.
207            //
208            // Update `reset` with the value of 'x-ratelimit-reset' header. Similarly, update
209            // `reset-after` with the 'x-ratelimit-reset-after' header.
210            //
211            // It _may_ be possible for the limit to be raised at any time, so check if it did from
212            // the value of the 'x-ratelimit-limit' header. If the limit was 5 and is now 7, add 2
213            // to the 'remaining'
214            if ratelimiting_bucket.is_none() {
215                return Ok(response);
216            }
217
218            let redo = if response.headers().get("x-ratelimit-global").is_some() {
219                drop(self.global.lock().await);
220
221                Ok(
222                    if let Some(retry_after) =
223                        parse_header::<f64>(response.headers(), "retry-after")?
224                    {
225                        debug!(
226                            "Ratelimited on route {:?} for {:?}s",
227                            ratelimiting_bucket, retry_after
228                        );
229                        (self.ratelimit_callback)(RatelimitInfo {
230                            timeout: Duration::from_secs_f64(retry_after),
231                            limit: 50,
232                            method: req.method,
233                            path: req.route.path().to_string(),
234                            global: true,
235                        });
236                        sleep(Duration::from_secs_f64(retry_after)).await;
237
238                        true
239                    } else {
240                        false
241                    },
242                )
243            } else {
244                bucket
245                    .lock()
246                    .await
247                    .post_hook(&response, &req, &self.ratelimit_callback, self.absolute_ratelimits)
248                    .await
249            };
250
251            if !redo.unwrap_or(true) {
252                return Ok(response);
253            }
254        }
255    }
256}
257
258/// A set of data containing information about the ratelimits for a particular
259/// [`RatelimitingBucket`], which is stored in [`Http`].
260///
261/// See the [Discord docs] on ratelimits for more information.
262///
263/// **Note**: You should _not_ mutate any of the fields, as this can help cause 429s.
264///
265/// [`Http`]: super::Http
266/// [Discord docs]: https://discord.com/developers/docs/topics/rate-limits
267#[derive(Debug)]
268pub struct Ratelimit {
269    /// The total number of requests that can be made in a period of time.
270    limit: i64,
271    /// The number of requests remaining in the period of time.
272    remaining: i64,
273    /// The absolute time when the interval resets.
274    reset: Option<SystemTime>,
275    /// The total time when the interval resets.
276    reset_after: Option<Duration>,
277}
278
279impl Ratelimit {
280    #[instrument(skip(ratelimit_callback))]
281    pub async fn pre_hook(
282        &mut self,
283        req: &Request<'_>,
284        ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
285    ) {
286        if self.limit() == 0 {
287            return;
288        }
289
290        let Some(reset) = self.reset else {
291            // We're probably in the past.
292            self.remaining = self.limit;
293            return;
294        };
295
296        let Ok(delay) = reset.duration_since(SystemTime::now()) else {
297            // if duration is negative (i.e. adequate time has passed since last call to this api)
298            if self.remaining() != 0 {
299                self.remaining -= 1;
300            }
301            return;
302        };
303
304        if self.remaining() == 0 {
305            debug!(
306                "Pre-emptive ratelimit on route {:?} for {}ms",
307                req.route.ratelimiting_bucket(),
308                delay.as_millis(),
309            );
310            ratelimit_callback(RatelimitInfo {
311                timeout: delay,
312                limit: self.limit,
313                method: req.method,
314                path: req.route.path().to_string(),
315                global: false,
316            });
317
318            sleep(delay).await;
319
320            return;
321        }
322
323        self.remaining -= 1;
324    }
325
326    #[instrument(skip(ratelimit_callback))]
327    pub async fn post_hook(
328        &mut self,
329        response: &Response,
330        req: &Request<'_>,
331        ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
332        absolute_ratelimits: bool,
333    ) -> Result<bool> {
334        if let Some(limit) = parse_header(response.headers(), "x-ratelimit-limit")? {
335            self.limit = limit;
336        }
337
338        if let Some(remaining) = parse_header(response.headers(), "x-ratelimit-remaining")? {
339            self.remaining = remaining;
340        }
341
342        if absolute_ratelimits {
343            if let Some(reset) = parse_header::<f64>(response.headers(), "x-ratelimit-reset")? {
344                self.reset = Some(std::time::UNIX_EPOCH + Duration::from_secs_f64(reset));
345            }
346        }
347
348        if let Some(reset_after) =
349            parse_header::<f64>(response.headers(), "x-ratelimit-reset-after")?
350        {
351            if !absolute_ratelimits {
352                self.reset = Some(SystemTime::now() + Duration::from_secs_f64(reset_after));
353            }
354
355            self.reset_after = Some(Duration::from_secs_f64(reset_after));
356        }
357
358        Ok(if response.status() != StatusCode::TOO_MANY_REQUESTS {
359            false
360        } else if let Some(retry_after) = parse_header::<f64>(response.headers(), "retry-after")? {
361            debug!(
362                "Ratelimited on route {:?} for {:?}s",
363                req.route.ratelimiting_bucket(),
364                retry_after
365            );
366            ratelimit_callback(RatelimitInfo {
367                timeout: Duration::from_secs_f64(retry_after),
368                limit: self.limit,
369                method: req.method,
370                path: req.route.path().to_string(),
371                global: false,
372            });
373
374            sleep(Duration::from_secs_f64(retry_after)).await;
375
376            true
377        } else {
378            false
379        })
380    }
381
382    /// The total number of requests that can be made in a period of time.
383    #[inline]
384    #[must_use]
385    pub const fn limit(&self) -> i64 {
386        self.limit
387    }
388
389    /// The number of requests remaining in the period of time.
390    #[inline]
391    #[must_use]
392    pub const fn remaining(&self) -> i64 {
393        self.remaining
394    }
395
396    /// The absolute time in milliseconds when the interval resets.
397    #[inline]
398    #[must_use]
399    pub const fn reset(&self) -> Option<SystemTime> {
400        self.reset
401    }
402
403    /// The total time in milliseconds when the interval resets.
404    #[inline]
405    #[must_use]
406    pub const fn reset_after(&self) -> Option<Duration> {
407        self.reset_after
408    }
409}
410
411impl Default for Ratelimit {
412    fn default() -> Self {
413        Self {
414            limit: i64::MAX,
415            remaining: i64::MAX,
416            reset: None,
417            reset_after: None,
418        }
419    }
420}
421
422fn parse_header<T: FromStr>(headers: &HeaderMap, header: &str) -> Result<Option<T>> {
423    let Some(header) = headers.get(header) else { return Ok(None) };
424
425    let unicode =
426        str::from_utf8(header.as_bytes()).map_err(|_| Error::from(HttpError::RateLimitUtf8))?;
427
428    let num = unicode.parse().map_err(|_| Error::from(HttpError::RateLimitI64F64))?;
429
430    Ok(Some(num))
431}
432
433#[cfg(test)]
434mod tests {
435    use std::error::Error as StdError;
436    use std::result::Result as StdResult;
437
438    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
439
440    use super::parse_header;
441    use crate::error::Error;
442    use crate::http::HttpError;
443
444    type Result<T> = StdResult<T, Box<dyn StdError>>;
445
446    fn headers() -> HeaderMap {
447        let pairs = &[
448            (HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_static("5")),
449            (HeaderName::from_static("x-ratelimit-remaining"), HeaderValue::from_static("4")),
450            (
451                HeaderName::from_static("x-ratelimit-reset"),
452                HeaderValue::from_static("1560704880.423"),
453            ),
454            (HeaderName::from_static("x-bad-num"), HeaderValue::from_static("abc")),
455            (
456                HeaderName::from_static("x-bad-unicode"),
457                HeaderValue::from_bytes(&[255, 255, 255, 255]).unwrap(),
458            ),
459        ];
460
461        let mut map = HeaderMap::with_capacity(pairs.len());
462
463        for (name, val) in pairs {
464            map.insert(name, val.clone());
465        }
466
467        map
468    }
469
470    #[test]
471    #[allow(clippy::float_cmp)]
472    fn test_parse_header_good() -> Result<()> {
473        let headers = headers();
474
475        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-limit")?.unwrap(), 5);
476        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-remaining")?.unwrap(), 4,);
477        assert_eq!(parse_header::<f64>(&headers, "x-ratelimit-reset")?.unwrap(), 1_560_704_880.423);
478
479        Ok(())
480    }
481
482    #[test]
483    fn test_parse_header_errors() {
484        let headers = headers();
485
486        assert!(matches!(
487            parse_header::<i64>(&headers, "x-bad-num").unwrap_err(),
488            Error::Http(HttpError::RateLimitI64F64)
489        ));
490        assert!(matches!(
491            parse_header::<i64>(&headers, "x-bad-unicode").unwrap_err(),
492            Error::Http(HttpError::RateLimitUtf8)
493        ));
494    }
495}