tokio_tungstenite/
lib.rs

1//! Async WebSocket usage.
2//!
//! This library is an implementation of WebSocket handshakes and streams. It
//! is based on the [`tungstenite`](https://crates.io/crates/tungstenite) crate, which implements
//! all required WebSocket protocol logic. This crate essentially brings tokio
//! support and tokio integration to it.
7//!
8//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
9//! so the socket is just a stream of messages coming in and going out.
10
11#![deny(missing_docs, unused_must_use, unused_mut, unused_imports, unused_import_braces)]
12
13pub use tungstenite;
14
15mod compat;
16#[cfg(feature = "connect")]
17mod connect;
18mod handshake;
19#[cfg(feature = "stream")]
20mod stream;
21#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
22mod tls;
23
24use std::io::{Read, Write};
25
26use compat::{cvt, AllowStd, ContextWaker};
27use futures_util::{
28    sink::{Sink, SinkExt},
29    stream::{FusedStream, Stream},
30};
31use log::*;
32use std::{
33    pin::Pin,
34    task::{Context, Poll},
35};
36use tokio::io::{AsyncRead, AsyncWrite};
37
38#[cfg(feature = "handshake")]
39use tungstenite::{
40    client::IntoClientRequest,
41    handshake::{
42        client::{ClientHandshake, Response},
43        server::{Callback, NoCallback},
44        HandshakeError,
45    },
46};
47use tungstenite::{
48    error::Error as WsError,
49    protocol::{Message, Role, WebSocket, WebSocketConfig},
50};
51
52#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
53pub use tls::Connector;
54#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
55pub use tls::{client_async_tls, client_async_tls_with_config};
56
57#[cfg(feature = "connect")]
58pub use connect::{connect_async, connect_async_with_config};
59
60#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "connect"))]
61pub use connect::connect_async_tls_with_config;
62
63#[cfg(feature = "stream")]
64pub use stream::MaybeTlsStream;
65
66use tungstenite::protocol::CloseFrame;
67
/// Performs the client side of a WebSocket handshake over an already
/// established `stream`.
///
/// `request` may be anything implementing `IntoClientRequest`: a URL string,
/// a URL, or a full `Request`. Passing a `Request` lets the caller add a
/// WebSocket sub-protocol or other custom headers.
///
/// The returned future resolves either to the connected `WebSocketStream<S>`
/// together with the server's handshake `Response`, or to an error if the
/// handshake does not succeed.
///
/// This is typically used by clients that have already opened, for example,
/// a TCP connection to the remote server. No explicit WebSocket configuration
/// is passed; use `client_async_with_config()` to supply one.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
91
/// Like `client_async()`, but lets the caller pass an explicit WebSocket
/// configuration.
///
/// `config` is forwarded to tungstenite's `ClientHandshake::start`. Refer to
/// `client_async()` for the remaining details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and its inner error returned as-is.
/// Any other handshake error is reported as a `WsError::Io` of kind `Other`
/// whose message is the handshake error's text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(allow_std, client_request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
114
/// Accepts a new WebSocket connection over the provided stream.
///
/// Runs the server half of the WebSocket handshake on `stream`, with no
/// header callback and no explicit configuration. The returned future resolves
/// to the established `WebSocketStream<S>` on success, or to an error if the
/// handshake fails.
///
/// This is typically used right after a socket has been accepted from a
/// `TcpListener`: hand that socket to this function to complete the server
/// side of the client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
133
/// Like `accept_async()`, but lets the caller pass an explicit WebSocket
/// configuration.
///
/// No header callback is installed. Refer to `accept_async()` for the
/// remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async_with_config(stream, no_callback, config).await
}
146
/// Accepts a new WebSocket connection over the provided stream, with a header
/// callback.
///
/// Behaves like `accept_async()`, but additionally runs `callback` during the
/// handshake. The callback sees the headers of the incoming request and may add
/// extra headers to the reply. No explicit configuration is passed.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
160
/// Like `accept_hdr_async()`, but lets the caller pass an explicit WebSocket
/// configuration.
///
/// `callback` and `config` are forwarded to `tungstenite::accept_hdr_with_config`.
/// Refer to `accept_hdr_async()` for the remaining details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and its inner error returned as-is.
/// Any other handshake error is reported as a `WsError::Io` of kind `Other`
/// whose message is the handshake error's text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match outcome {
        Ok(ws_stream) => Ok(ws_stream),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
181
182/// A wrapper around an underlying raw stream which implements the WebSocket
183/// protocol.
184///
185/// A `WebSocketStream<S>` represents a handshake that has been completed
186/// successfully and both the server and the client are ready for receiving
187/// and sending data. Message from a `WebSocketStream<S>` are accessible
188/// through the respective `Stream` and `Sink`. Check more information about
189/// them in `futures-rs` crate documentation or have a look on the examples
190/// and unit tests for this crate.
191#[derive(Debug)]
192pub struct WebSocketStream<S> {
193    inner: WebSocket<AllowStd<S>>,
194    closing: bool,
195    ended: bool,
196    /// Tungstenite is probably ready to receive more data.
197    ///
198    /// `false` once start_send hits `WouldBlock` errors.
199    /// `true` initially and after `flush`ing.
200    ready: bool,
201}
202
203impl<S> WebSocketStream<S> {
204    /// Convert a raw socket into a WebSocketStream without performing a
205    /// handshake.
206    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
207    where
208        S: AsyncRead + AsyncWrite + Unpin,
209    {
210        handshake::without_handshake(stream, move |allow_std| {
211            WebSocket::from_raw_socket(allow_std, role, config)
212        })
213        .await
214    }
215
216    /// Convert a raw socket into a WebSocketStream without performing a
217    /// handshake.
218    pub async fn from_partially_read(
219        stream: S,
220        part: Vec<u8>,
221        role: Role,
222        config: Option<WebSocketConfig>,
223    ) -> Self
224    where
225        S: AsyncRead + AsyncWrite + Unpin,
226    {
227        handshake::without_handshake(stream, move |allow_std| {
228            WebSocket::from_partially_read(allow_std, part, role, config)
229        })
230        .await
231    }
232
233    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
234        Self { inner: ws, closing: false, ended: false, ready: true }
235    }
236
237    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
238    where
239        S: Unpin,
240        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
241        AllowStd<S>: Read + Write,
242    {
243        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
244        if let Some((kind, ctx)) = ctx {
245            self.inner.get_mut().set_waker(kind, ctx.waker());
246        }
247        f(&mut self.inner)
248    }
249
250    /// Returns a shared reference to the inner stream.
251    pub fn get_ref(&self) -> &S
252    where
253        S: AsyncRead + AsyncWrite + Unpin,
254    {
255        self.inner.get_ref().get_ref()
256    }
257
258    /// Returns a mutable reference to the inner stream.
259    pub fn get_mut(&mut self) -> &mut S
260    where
261        S: AsyncRead + AsyncWrite + Unpin,
262    {
263        self.inner.get_mut().get_mut()
264    }
265
266    /// Returns a reference to the configuration of the tungstenite stream.
267    pub fn get_config(&self) -> &WebSocketConfig {
268        self.inner.get_config()
269    }
270
271    /// Close the underlying web socket
272    pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
273    where
274        S: AsyncRead + AsyncWrite + Unpin,
275    {
276        let msg = msg.map(|msg| msg.into_owned());
277        self.send(Message::Close(msg)).await
278    }
279}
280
281impl<T> Stream for WebSocketStream<T>
282where
283    T: AsyncRead + AsyncWrite + Unpin,
284{
285    type Item = Result<Message, WsError>;
286
287    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
288        trace!("{}:{} Stream.poll_next", file!(), line!());
289
290        // The connection has been closed or a critical error has occurred.
291        // We have already returned the error to the user, the `Stream` is unusable,
292        // so we assume that the stream has been "fused".
293        if self.ended {
294            return Poll::Ready(None);
295        }
296
297        match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
298            trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
299            cvt(s.read())
300        })) {
301            Ok(v) => Poll::Ready(Some(Ok(v))),
302            Err(e) => {
303                self.ended = true;
304                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
305                    Poll::Ready(None)
306                } else {
307                    Poll::Ready(Some(Err(e)))
308                }
309            }
310        }
311    }
312}
313
314impl<T> FusedStream for WebSocketStream<T>
315where
316    T: AsyncRead + AsyncWrite + Unpin,
317{
318    fn is_terminated(&self) -> bool {
319        self.ended
320    }
321}
322
323impl<T> Sink<Message> for WebSocketStream<T>
324where
325    T: AsyncRead + AsyncWrite + Unpin,
326{
327    type Error = WsError;
328
329    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330        if self.ready {
331            Poll::Ready(Ok(()))
332        } else {
333            // Currently blocked so try to flush the blockage away
334            (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
335                self.ready = true;
336                r
337            })
338        }
339    }
340
341    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
342        match (*self).with_context(None, |s| s.write(item)) {
343            Ok(()) => {
344                self.ready = true;
345                Ok(())
346            }
347            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
348                // the message was accepted and queued so not an error
349                // but `poll_ready` will now start trying to flush the block
350                self.ready = false;
351                Ok(())
352            }
353            Err(e) => {
354                self.ready = true;
355                debug!("websocket start_send error: {}", e);
356                Err(e)
357            }
358        }
359    }
360
361    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362        (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
363            self.ready = true;
364            match r {
365                // WebSocket connection has just been closed. Flushing completed, not an error.
366                Err(WsError::ConnectionClosed) => Ok(()),
367                other => other,
368            }
369        })
370    }
371
372    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373        self.ready = true;
374        let res = if self.closing {
375            // After queueing it, we call `flush` to drive the close handshake to completion.
376            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
377        } else {
378            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
379        };
380
381        match res {
382            Ok(()) => Poll::Ready(Ok(())),
383            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
384            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
385                trace!("WouldBlock");
386                self.closing = true;
387                Poll::Pending
388            }
389            Err(err) => {
390                debug!("websocket close error: {}", err);
391                Poll::Ready(Err(err))
392            }
393        }
394    }
395}
396
/// Extracts the host name from a client request's URI.
///
/// Returns `UrlError::NoHostName` when the URI has no host component. With the
/// `__rustls-tls` feature enabled, a bracketed IPv6 literal such as `[::1]` is
/// returned without its surrounding brackets.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    let host = request
        .uri()
        .host()
        .ok_or(WsError::Url(tungstenite::error::UrlError::NoHostName))?;

    // rustls expects IPv6 addresses without the surrounding [] brackets
    #[cfg(feature = "__rustls-tls")]
    let host = host.strip_prefix('[').and_then(|inner| inner.strip_suffix(']')).unwrap_or(host);

    Ok(host.to_string())
}
409
#[cfg(test)]
mod tests {
    use crate::{compat::AllowStd, WebSocketStream};
    use std::io::{Read, Write};
    use tokio::net::TcpStream;

    // Each helper only compiles when `T` satisfies the named bound, so calling
    // it with a concrete type is a compile-time trait assertion.
    fn require_read<T: Read>() {}
    fn require_write<T: Write>() {}
    fn require_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        require_read::<AllowStd<TcpStream>>();
        require_write::<AllowStd<TcpStream>>();
        require_unpin::<WebSocketStream<TcpStream>>();

        #[cfg(feature = "connect")]
        {
            use crate::stream::MaybeTlsStream;
            use tokio::io::{AsyncReadExt, AsyncWriteExt};

            fn require_async_read<T: AsyncReadExt>() {}
            fn require_async_write<T: AsyncWriteExt>() {}

            require_async_read::<MaybeTlsStream<TcpStream>>();
            require_async_write::<MaybeTlsStream<TcpStream>>();
            require_unpin::<WebSocketStream<MaybeTlsStream<TcpStream>>>();
        }
    }
}