tungstenite/
stream.rs

//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! There is no dependency on any particular TLS implementation. Anything like
//! `native_tls` or `openssl` will work as long as it provides a TLS stream that
//! implements the standard `Read + Write` traits.
6
7#[cfg(feature = "__rustls-tls")]
8use std::ops::Deref;
9use std::{
10    fmt::{self, Debug},
11    io::{Read, Result as IoResult, Write},
12};
13
14use std::net::TcpStream;
15
16#[cfg(feature = "native-tls")]
17use native_tls_crate::TlsStream;
18#[cfg(feature = "__rustls-tls")]
19use rustls::StreamOwned;
20
/// Stream mode, either plain TCP or TLS.
///
/// Besides `Clone`/`Copy`/`Debug`, the mode derives `PartialEq`, `Eq` and `Hash`.
/// Callers can therefore compare modes directly (`mode == Mode::Tls`) or use
/// them as set/map keys instead of falling back to `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Plain mode (`ws://` URL).
    Plain,
    /// TLS mode (`wss://` URL).
    Tls,
}
29
/// Toggles the `TCP_NODELAY` socket option on a stream.
///
/// Plain sockets set the option themselves. Wrapping streams forward the call
/// to the socket they contain.
pub trait NoDelay {
    /// Sets `TCP_NODELAY` to `nodelay` (`true` disables Nagle's algorithm).
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported while applying the option.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>;
}
35
36impl NoDelay for TcpStream {
37    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
38        TcpStream::set_nodelay(self, nodelay)
39    }
40}
41
#[cfg(feature = "native-tls")]
impl<S> NoDelay for TlsStream<S>
where
    S: Read + Write + NoDelay,
{
    /// Applies the option to the socket wrapped by the `native-tls` stream.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        // `get_mut` hands out the wrapped socket `S`, whose `NoDelay` impl does the work.
        let socket = self.get_mut();
        NoDelay::set_nodelay(socket, nodelay)
    }
}
48
#[cfg(feature = "__rustls-tls")]
impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
    S: Deref<Target = rustls::ConnectionCommon<SD>>,
    SD: rustls::SideData,
    T: Read + Write + NoDelay,
{
    /// Applies the option to the transport held in the owned stream's `sock` field.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        let transport = &mut self.sock;
        NoDelay::set_nodelay(transport, nodelay)
    }
}
60
/// A stream that is either plain or protected with TLS.
///
/// Which TLS variants exist depends on the enabled cargo features. The enum is
/// `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[non_exhaustive]
pub enum MaybeTlsStream<S: Read + Write> {
    /// Unencrypted socket stream.
    Plain(S),
    /// Encrypted socket stream using `native-tls`.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsStream<S>),
    /// Encrypted socket stream using `rustls`.
    #[cfg(feature = "__rustls-tls")]
    Rustls(rustls::StreamOwned<rustls::ClientConnection, S>),
}
73
74impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
78            #[cfg(feature = "native-tls")]
79            Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
80            #[cfg(feature = "__rustls-tls")]
81            Self::Rustls(s) => {
82                struct RustlsStreamDebug<'a, S: Read + Write>(
83                    &'a rustls::StreamOwned<rustls::ClientConnection, S>,
84                );
85
86                impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
87                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88                        f.debug_struct("StreamOwned")
89                            .field("conn", &self.0.conn)
90                            .field("sock", &self.0.sock)
91                            .finish()
92                    }
93                }
94
95                f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
96            }
97        }
98    }
99}
100
101impl<S: Read + Write> Read for MaybeTlsStream<S> {
102    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
103        match *self {
104            MaybeTlsStream::Plain(ref mut s) => s.read(buf),
105            #[cfg(feature = "native-tls")]
106            MaybeTlsStream::NativeTls(ref mut s) => s.read(buf),
107            #[cfg(feature = "__rustls-tls")]
108            MaybeTlsStream::Rustls(ref mut s) => s.read(buf),
109        }
110    }
111}
112
113impl<S: Read + Write> Write for MaybeTlsStream<S> {
114    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
115        match *self {
116            MaybeTlsStream::Plain(ref mut s) => s.write(buf),
117            #[cfg(feature = "native-tls")]
118            MaybeTlsStream::NativeTls(ref mut s) => s.write(buf),
119            #[cfg(feature = "__rustls-tls")]
120            MaybeTlsStream::Rustls(ref mut s) => s.write(buf),
121        }
122    }
123
124    fn flush(&mut self) -> IoResult<()> {
125        match *self {
126            MaybeTlsStream::Plain(ref mut s) => s.flush(),
127            #[cfg(feature = "native-tls")]
128            MaybeTlsStream::NativeTls(ref mut s) => s.flush(),
129            #[cfg(feature = "__rustls-tls")]
130            MaybeTlsStream::Rustls(ref mut s) => s.flush(),
131        }
132    }
133}
134
135impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
136    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
137        match *self {
138            MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay),
139            #[cfg(feature = "native-tls")]
140            MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay),
141            #[cfg(feature = "__rustls-tls")]
142            MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay),
143        }
144    }
145}