1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::marker::Unpin;
48use std::pin::Pin;
49use std::task::{Context, Poll};
50
51use bytes::Bytes;
52use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
53use tokio::sync::oneshot;
54#[cfg(any(feature = "http1", feature = "http2"))]
55use tracing::trace;
56
57use crate::common::io::Rewind;
58
/// An upgraded connection.
///
/// Holds the type-erased IO object (boxed as `dyn Io + Send`) inside a
/// `Rewind` wrapper, together with whatever `read_buf` was supplied to
/// `Upgraded::new`. Reads and writes are forwarded to that wrapper, and the
/// concrete IO type can be recovered with `Upgraded::downcast`.
pub struct Upgraded {
    // Type-erased IO plus the buffered bytes passed in at construction.
    io: Rewind<Box<dyn Io + Send>>,
}
70
/// A future resolving to the `Upgraded` connection once the upgrade completes.
///
/// Obtained through `on`. When the message carried no pending upgrade, `rx` is
/// `None` and polling yields a "no upgrade" user error right away.
pub struct OnUpgrade {
    // Receiving half of the channel whose sender is held by `Pending`;
    // `None` means there is nothing to wait for.
    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
}
77
/// The deconstructed parts of an `Upgraded` connection.
///
/// Returned by `Upgraded::downcast` when the requested type matches the
/// concrete IO type that was stored.
#[derive(Debug)]
pub struct Parts<T> {
    /// The concrete IO object that was wrapped by the `Upgraded`.
    pub io: T,
    /// Bytes the `Rewind` wrapper still held buffered when it was unwrapped.
    ///
    /// NOTE(review): presumably these were already read from `io` but not yet
    /// consumed, so callers should handle them before reading `io` again — confirm.
    pub read_buf: Bytes,
    // Private field: prevents code outside this module from building `Parts`
    // with a struct literal or destructuring it exhaustively.
    _inner: (),
}
97
98pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107 msg.on_upgrade()
108}
109
/// Sending half of a pending upgrade.
///
/// Completing it (via `fulfill` or `manual`) resolves the paired `OnUpgrade`.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
    // Carries either the upgraded IO or an error to the waiting `OnUpgrade`.
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
114
/// Creates a linked `Pending` / `OnUpgrade` pair backed by a new oneshot channel.
///
/// Whatever is later sent through the `Pending` half (`fulfill` or `manual`)
/// is what the `OnUpgrade` future resolves to.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let on_upgrade = OnUpgrade { rx: Some(receiver) };
    (Pending { tx: sender }, on_upgrade)
}
120
121impl Upgraded {
124 #[cfg(any(feature = "http1", feature = "http2", test))]
125 pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
126 where
127 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
128 {
129 Upgraded {
130 io: Rewind::new_buffered(Box::new(io), read_buf),
131 }
132 }
133
134 pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
139 let (io, buf) = self.io.into_inner();
140 match io.__hyper_downcast() {
141 Ok(t) => Ok(Parts {
142 io: *t,
143 read_buf: buf,
144 _inner: (),
145 }),
146 Err(io) => Err(Upgraded {
147 io: Rewind::new_buffered(io, buf),
148 }),
149 }
150 }
151}
152
153impl AsyncRead for Upgraded {
154 fn poll_read(
155 mut self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &mut ReadBuf<'_>,
158 ) -> Poll<io::Result<()>> {
159 Pin::new(&mut self.io).poll_read(cx, buf)
160 }
161}
162
163impl AsyncWrite for Upgraded {
164 fn poll_write(
165 mut self: Pin<&mut Self>,
166 cx: &mut Context<'_>,
167 buf: &[u8],
168 ) -> Poll<io::Result<usize>> {
169 Pin::new(&mut self.io).poll_write(cx, buf)
170 }
171
172 fn poll_write_vectored(
173 mut self: Pin<&mut Self>,
174 cx: &mut Context<'_>,
175 bufs: &[io::IoSlice<'_>],
176 ) -> Poll<io::Result<usize>> {
177 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
178 }
179
180 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181 Pin::new(&mut self.io).poll_flush(cx)
182 }
183
184 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185 Pin::new(&mut self.io).poll_shutdown(cx)
186 }
187
188 fn is_write_vectored(&self) -> bool {
189 self.io.is_write_vectored()
190 }
191}
192
193impl fmt::Debug for Upgraded {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("Upgraded").finish()
196 }
197}
198
199impl OnUpgrade {
202 pub(super) fn none() -> Self {
203 OnUpgrade { rx: None }
204 }
205
206 #[cfg(feature = "http1")]
207 pub(super) fn is_none(&self) -> bool {
208 self.rx.is_none()
209 }
210}
211
212impl Future for OnUpgrade {
213 type Output = Result<Upgraded, crate::Error>;
214
215 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
216 match self.rx {
217 Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
218 Ok(Ok(upgraded)) => Ok(upgraded),
219 Ok(Err(err)) => Err(err),
220 Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
221 }),
222 None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
223 }
224 }
225}
226
227impl fmt::Debug for OnUpgrade {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 f.debug_struct("OnUpgrade").finish()
230 }
231}
232
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by sending the upgraded IO to the `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        self.resolve(Ok(upgraded));
    }

    /// Signals the `OnUpgrade` side that the upgrade is being handled manually,
    /// by sending a "manual upgrade" user error instead of an upgraded IO.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        self.resolve(Err(crate::Error::new_user_manual_upgrade()));
    }

    /// Sends `outcome` to the paired `OnUpgrade`.
    ///
    /// A failed send only means the receiver is gone; the unsent value is
    /// handed back in the error and dropped here, deliberately ignored.
    fn resolve(self, outcome: crate::Result<Upgraded>) {
        let _ = self.tx.send(outcome);
    }
}
250
/// Error cause attached when an `OnUpgrade` finds that its sending half was
/// dropped without ever completing the upgrade.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// A leaf error: it has no underlying `source`.
impl StdError for UpgradeExpected {}
267
/// The object-safe bundle of IO traits an upgraded connection must provide.
///
/// It is stored as `Box<dyn Io + Send>` inside `Upgraded`.
pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
    /// Returns the `TypeId` of the concrete implementing type.
    ///
    /// When called through a trait object, this dispatches via the vtable, so it
    /// reports the erased concrete type rather than `dyn Io`. The unsafe
    /// downcast in `impl dyn Io + Send` depends on that. The blanket impl
    /// below leaves this default in place for every implementor.
    fn __hyper_type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

// Every type meeting the bounds is an `Io` automatically; the empty body keeps
// the default `__hyper_type_id`.
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
277
278impl dyn Io + Send {
279 fn __hyper_is<T: Io>(&self) -> bool {
280 let t = TypeId::of::<T>();
281 self.__hyper_type_id() == t
282 }
283
284 fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
285 if self.__hyper_is::<T>() {
286 unsafe {
288 let raw: *mut dyn Io = Box::into_raw(self);
289 Ok(Box::from_raw(raw as *mut T))
290 }
291 } else {
292 Err(self)
293 }
294 }
295}
296
297mod sealed {
298 use super::OnUpgrade;
299
300 pub trait CanUpgrade {
301 fn on_upgrade(self) -> OnUpgrade;
302 }
303
304 impl<B> CanUpgrade for http::Request<B> {
305 fn on_upgrade(mut self) -> OnUpgrade {
306 self.extensions_mut()
307 .remove::<OnUpgrade>()
308 .unwrap_or_else(OnUpgrade::none)
309 }
310 }
311
312 impl<B> CanUpgrade for &'_ mut http::Request<B> {
313 fn on_upgrade(self) -> OnUpgrade {
314 self.extensions_mut()
315 .remove::<OnUpgrade>()
316 .unwrap_or_else(OnUpgrade::none)
317 }
318 }
319
320 impl<B> CanUpgrade for http::Response<B> {
321 fn on_upgrade(mut self) -> OnUpgrade {
322 self.extensions_mut()
323 .remove::<OnUpgrade>()
324 .unwrap_or_else(OnUpgrade::none)
325 }
326 }
327
328 impl<B> CanUpgrade for &'_ mut http::Response<B> {
329 fn on_upgrade(self) -> OnUpgrade {
330 self.extensions_mut()
331 .remove::<OnUpgrade>()
332 .unwrap_or_else(OnUpgrade::none)
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        // Asking for the wrong type must hand the `Upgraded` back...
        let upgraded = match upgraded.downcast::<std::io::Cursor<Vec<u8>>>() {
            Ok(_) => panic!("Mock must not downcast to Cursor<Vec<u8>>"),
            Err(returned) => returned,
        };

        // ...so that asking for the real type afterwards still succeeds.
        assert!(upgraded.downcast::<Mock>().is_ok());
    }

    /// IO stand-in: accepts every write and must never be read, flushed or shut down.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Report every byte as written.
            let written = buf.len();
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}