http_body/
limited.rs

1use crate::{Body, SizeHint};
2use bytes::Buf;
3use http::HeaderMap;
4use pin_project_lite::pin_project;
5use std::error::Error;
6use std::fmt;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10pin_project! {
11    /// A length limited body.
12    ///
13    /// This body will return an error if more than the configured number
14    /// of bytes are returned on polling the wrapped body.
15    #[derive(Clone, Copy, Debug)]
16    pub struct Limited<B> {
17        remaining: usize,
18        #[pin]
19        inner: B,
20    }
21}
22
23impl<B> Limited<B> {
24    /// Create a new `Limited`.
25    pub fn new(inner: B, limit: usize) -> Self {
26        Self {
27            remaining: limit,
28            inner,
29        }
30    }
31}
32
33impl<B> Body for Limited<B>
34where
35    B: Body,
36    B::Error: Into<Box<dyn Error + Send + Sync>>,
37{
38    type Data = B::Data;
39    type Error = Box<dyn Error + Send + Sync>;
40
41    fn poll_data(
42        self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
45        let this = self.project();
46        let res = match this.inner.poll_data(cx) {
47            Poll::Pending => return Poll::Pending,
48            Poll::Ready(None) => None,
49            Poll::Ready(Some(Ok(data))) => {
50                if data.remaining() > *this.remaining {
51                    *this.remaining = 0;
52                    Some(Err(LengthLimitError.into()))
53                } else {
54                    *this.remaining -= data.remaining();
55                    Some(Ok(data))
56                }
57            }
58            Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
59        };
60
61        Poll::Ready(res)
62    }
63
64    fn poll_trailers(
65        self: Pin<&mut Self>,
66        cx: &mut Context<'_>,
67    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
68        let this = self.project();
69        let res = match this.inner.poll_trailers(cx) {
70            Poll::Pending => return Poll::Pending,
71            Poll::Ready(Ok(data)) => Ok(data),
72            Poll::Ready(Err(err)) => Err(err.into()),
73        };
74
75        Poll::Ready(res)
76    }
77
78    fn is_end_stream(&self) -> bool {
79        self.inner.is_end_stream()
80    }
81
82    fn size_hint(&self) -> SizeHint {
83        use std::convert::TryFrom;
84        match u64::try_from(self.remaining) {
85            Ok(n) => {
86                let mut hint = self.inner.size_hint();
87                if hint.lower() >= n {
88                    hint.set_exact(n)
89                } else if let Some(max) = hint.upper() {
90                    hint.set_upper(n.min(max))
91                } else {
92                    hint.set_upper(n)
93                }
94                hint
95            }
96            Err(_) => self.inner.size_hint(),
97        }
98    }
99}
100
/// Error yielded by [`Limited`] in place of the first data chunk that would
/// push the body past its configured length limit.
#[non_exhaustive]
#[derive(Debug)]
pub struct LengthLimitError;
105
106impl fmt::Display for LengthLimitError {
107    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108        f.write_str("length limit exceeded")
109    }
110}
111
112impl Error for LengthLimitError {}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Full;
    use bytes::Bytes;
    use std::convert::Infallible;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(7);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));

        // Exceeding the limit exhausts the budget, so no further data is advertised.
        assert_eq!(body.size_hint().upper(), Some(0));
    }

    /// Test body that yields each static slice as its own data chunk, in
    /// order, then ends. Its trailers are always an empty `HeaderMap`.
    struct Chunky(&'static [&'static [u8]]);

    impl Body for Chunky {
        type Data = &'static [u8];
        type Error = Infallible;

        fn poll_data(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            match self.0.split_first() {
                Some((&head, tail)) => {
                    self.0 = tail;
                    Poll::Ready(Some(Ok(head)))
                }
                None => Poll::Ready(None),
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(Some(HeaderMap::new())))
        }
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(4);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[1]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);
        let trailers = body.trailers().await.unwrap();
        assert_eq!(trailers, Some(HeaderMap::new()))
    }

    /// Identifies which poll method of `ErrorBody` failed.
    #[derive(Debug)]
    enum ErrorBodyError {
        Data,
        Trailers,
    }

    impl fmt::Display for ErrorBodyError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                ErrorBodyError::Data => f.write_str("error polling body data"),
                ErrorBodyError::Trailers => f.write_str("error polling body trailers"),
            }
        }
    }

    impl Error for ErrorBodyError {}

    /// Test body whose data and trailer polls always fail immediately.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError::Data)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Err(ErrorBodyError::Trailers))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
    }

    #[tokio::test]
    async fn trailers_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.trailers().await.unwrap_err();
        assert!(matches!(
            error.downcast_ref(),
            Some(ErrorBodyError::Trailers)
        ));
    }
}