hyper/
upgrade.rs

1//! HTTP Upgrades
2//!
3//! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since
4//! several concepts in HTTP allow for first talking HTTP, and then converting
5//! to a different protocol, this module conflates them into a single API.
6//! Those include:
7//!
8//! - HTTP/1.1 Upgrades
9//! - HTTP `CONNECT`
10//!
//! You are responsible for any other prerequisites to establish an upgrade,
//! such as sending the appropriate headers, methods, and status codes. You can
//! then use [`on`][] to obtain a `Future` that resolves to the upgraded
//! connection object, or to an error if the upgrade fails.
15//!
16//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
17//!
18//! # Client
19//!
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
//! either the appropriate method (for `CONNECT`) or headers such as `Upgrade`
//! and `Connection` on the `http::Request`. Once you receive the
//! `http::Response`, you must check that the server agreed to the upgrade
//! (for example, with a `101` status code), and then get the `Future` from the
//! `Response` with [`on`][].
26//!
27//! # Server
28//!
//! Receiving upgrade requests in a server requires you to check the relevant
//! headers in a `Request`, and if an upgrade should be done, you then send the
//! corresponding headers in a response. To wait for hyper to finish the
//! upgrade, call [`on`][] with the `Request`, then spawn a task that awaits
//! the returned future.
34//!
35//! # Example
36//!
37//! See [this example][example] showing how upgrades work with both
38//! Clients and Servers.
39//!
40//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs
41
42use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::marker::Unpin;
48use std::pin::Pin;
49use std::task::{Context, Poll};
50
51use bytes::Bytes;
52use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
53use tokio::sync::oneshot;
54#[cfg(any(feature = "http1", feature = "http2"))]
55use tracing::trace;
56
57use crate::common::io::Rewind;
58
59/// An upgraded HTTP connection.
60///
61/// This type holds a trait object internally of the original IO that
62/// was used to speak HTTP before the upgrade. It can be used directly
63/// as a `Read` or `Write` for convenience.
64///
65/// Alternatively, if the exact type is known, this can be deconstructed
66/// into its parts.
67pub struct Upgraded {
68    io: Rewind<Box<dyn Io + Send>>,
69}
70
71/// A future for a possible HTTP upgrade.
72///
73/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
74pub struct OnUpgrade {
75    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
76}
77
78/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
79///
80/// Includes the original IO type, and a read buffer of bytes that the
81/// HTTP state machine may have already read before completing an upgrade.
82#[derive(Debug)]
83pub struct Parts<T> {
84    /// The original IO object used before the upgrade.
85    pub io: T,
86    /// A buffer of bytes that have been read but not processed as HTTP.
87    ///
88    /// For instance, if the `Connection` is used for an HTTP upgrade request,
89    /// it is possible the server sent back the first bytes of the new protocol
90    /// along with the response upgrade.
91    ///
92    /// You will want to check for any existing bytes if you plan to continue
93    /// communicating on the IO object.
94    pub read_buf: Bytes,
95    _inner: (),
96}
97
98/// Gets a pending HTTP upgrade from this message.
99///
100/// This can be called on the following types:
101///
102/// - `http::Request<B>`
103/// - `http::Response<B>`
104/// - `&mut http::Request<B>`
105/// - `&mut http::Response<B>`
106pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107    msg.on_upgrade()
108}
109
/// Sending half of a pending upgrade; `pending()` pairs it with an `OnUpgrade`.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
    // Delivers the upgrade outcome (success or error) to the paired `OnUpgrade`.
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
114
/// Creates a connected `Pending` / `OnUpgrade` pair backed by a oneshot channel.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let on_upgrade = OnUpgrade {
        rx: Some(receiver),
    };
    (Pending { tx: sender }, on_upgrade)
}
120
121// ===== impl Upgraded =====
122
123impl Upgraded {
124    #[cfg(any(feature = "http1", feature = "http2", test))]
125    pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
126    where
127        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
128    {
129        Upgraded {
130            io: Rewind::new_buffered(Box::new(io), read_buf),
131        }
132    }
133
134    /// Tries to downcast the internal trait object to the type passed.
135    ///
136    /// On success, returns the downcasted parts. On error, returns the
137    /// `Upgraded` back.
138    pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
139        let (io, buf) = self.io.into_inner();
140        match io.__hyper_downcast() {
141            Ok(t) => Ok(Parts {
142                io: *t,
143                read_buf: buf,
144                _inner: (),
145            }),
146            Err(io) => Err(Upgraded {
147                io: Rewind::new_buffered(io, buf),
148            }),
149        }
150    }
151}
152
153impl AsyncRead for Upgraded {
154    fn poll_read(
155        mut self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        buf: &mut ReadBuf<'_>,
158    ) -> Poll<io::Result<()>> {
159        Pin::new(&mut self.io).poll_read(cx, buf)
160    }
161}
162
163impl AsyncWrite for Upgraded {
164    fn poll_write(
165        mut self: Pin<&mut Self>,
166        cx: &mut Context<'_>,
167        buf: &[u8],
168    ) -> Poll<io::Result<usize>> {
169        Pin::new(&mut self.io).poll_write(cx, buf)
170    }
171
172    fn poll_write_vectored(
173        mut self: Pin<&mut Self>,
174        cx: &mut Context<'_>,
175        bufs: &[io::IoSlice<'_>],
176    ) -> Poll<io::Result<usize>> {
177        Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
178    }
179
180    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181        Pin::new(&mut self.io).poll_flush(cx)
182    }
183
184    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185        Pin::new(&mut self.io).poll_shutdown(cx)
186    }
187
188    fn is_write_vectored(&self) -> bool {
189        self.io.is_write_vectored()
190    }
191}
192
193impl fmt::Debug for Upgraded {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.debug_struct("Upgraded").finish()
196    }
197}
198
199// ===== impl OnUpgrade =====
200
201impl OnUpgrade {
202    pub(super) fn none() -> Self {
203        OnUpgrade { rx: None }
204    }
205
206    #[cfg(feature = "http1")]
207    pub(super) fn is_none(&self) -> bool {
208        self.rx.is_none()
209    }
210}
211
212impl Future for OnUpgrade {
213    type Output = Result<Upgraded, crate::Error>;
214
215    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
216        match self.rx {
217            Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
218                Ok(Ok(upgraded)) => Ok(upgraded),
219                Ok(Err(err)) => Err(err),
220                Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
221            }),
222            None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
223        }
224    }
225}
226
227impl fmt::Debug for OnUpgrade {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        f.debug_struct("OnUpgrade").finish()
230    }
231}
232
233// ===== impl Pending =====
234
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by handing `upgraded` to the paired `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        // The paired `OnUpgrade` may already be gone; there is nobody left to
        // notify then, so a failed send is deliberately ignored.
        let _ = self.tx.send(Ok(upgraded));
    }

    /// Don't fulfill the pending Upgrade, but instead signal that
    /// upgrades are handled manually.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let manual_err = crate::Error::new_user_manual_upgrade();
        // As in `fulfill`, a dropped receiver is not an error here.
        let _ = self.tx.send(Err(manual_err));
    }
}
250
251// ===== impl UpgradeExpected =====
252
/// Error cause attached when an upgrade was expected but was canceled for
/// whatever reason.
///
/// This likely means the actual `Conn` future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// No underlying `source()`: this type is itself the root cause attached to
// the canceled error.
impl StdError for UpgradeExpected {}
267
268// ===== impl Io =====
269
270pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
271    fn __hyper_type_id(&self) -> TypeId {
272        TypeId::of::<Self>()
273    }
274}
275
276impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
277
278impl dyn Io + Send {
279    fn __hyper_is<T: Io>(&self) -> bool {
280        let t = TypeId::of::<T>();
281        self.__hyper_type_id() == t
282    }
283
284    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
285        if self.__hyper_is::<T>() {
286            // Taken from `std::error::Error::downcast()`.
287            unsafe {
288                let raw: *mut dyn Io = Box::into_raw(self);
289                Ok(Box::from_raw(raw as *mut T))
290            }
291        } else {
292            Err(self)
293        }
294    }
295}
296
297mod sealed {
298    use super::OnUpgrade;
299
300    pub trait CanUpgrade {
301        fn on_upgrade(self) -> OnUpgrade;
302    }
303
304    impl<B> CanUpgrade for http::Request<B> {
305        fn on_upgrade(mut self) -> OnUpgrade {
306            self.extensions_mut()
307                .remove::<OnUpgrade>()
308                .unwrap_or_else(OnUpgrade::none)
309        }
310    }
311
312    impl<B> CanUpgrade for &'_ mut http::Request<B> {
313        fn on_upgrade(self) -> OnUpgrade {
314            self.extensions_mut()
315                .remove::<OnUpgrade>()
316                .unwrap_or_else(OnUpgrade::none)
317        }
318    }
319
320    impl<B> CanUpgrade for http::Response<B> {
321        fn on_upgrade(mut self) -> OnUpgrade {
322            self.extensions_mut()
323                .remove::<OnUpgrade>()
324                .unwrap_or_else(OnUpgrade::none)
325        }
326    }
327
328    impl<B> CanUpgrade for &'_ mut http::Response<B> {
329        fn on_upgrade(self) -> OnUpgrade {
330            self.extensions_mut()
331                .remove::<OnUpgrade>()
332                .unwrap_or_else(OnUpgrade::none)
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::from_static(b"leftover"));

        // A downcast to the wrong type must hand the `Upgraded` back intact,
        // including the bytes buffered before the upgrade.
        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();

        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(parts.read_buf, Bytes::from_static(b"leftover"));
    }

    // TODO: replace with tokio_test::io when it can test write_buf
    /// Test IO that accepts every write in full; all other operations are
    /// unreachable in these tests.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}