tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
2//!
3//! # Why do I need to call `poll_flush`?
4//!
5//! Most TLS implementations will have an internal buffer to improve throughput,
6//! and rustls is no exception.
7//!
//! When we write data to `TlsStream`, we always write it to the rustls buffer first,
//! then take the encrypted data out of rustls and write it to the data channel (like `TcpStream`).
//! When the data channel is pending, some data may remain in the rustls buffer.
11//!
//! To keep it simple and correct, [TlsStream] will behave like `BufWriter`.
13//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed to be written to `TcpStream`.
14//! You must call `poll_flush` to ensure that it is written to `TcpStream`.
15//!
16//! You should call `poll_flush` at the appropriate time,
17//! such as when a period of `poll_write` write is complete and there is no more data to write.
18//!
19//! ## Why don't we write during `poll_read`?
20//!
21//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! These bugs can be worked around, but the workarounds cause performance degradation (reverse false wakeup).
//!
//! Writing during a read (reverse write) would also prevent us from implementing full duplex in the future.
25//!
26//! see <https://github.com/tokio-rs/tls/issues/40>
27//!
28//! ## Why can't we handle it like `native-tls`?
29//!
//! When the data channel returns pending, `native-tls` deliberately misreports the number of bytes it has consumed.
//! This means that `poll_write` does not return `Ready` until the data has actually been written to the data channel,
//! which avoids the need to call `poll_flush`.
//!
//! However, this does not conform to the convention of the `AsyncWrite` trait:
//! if you pass different data to two consecutive `poll_write` calls, it may cause unexpected behavior.
36//!
37//! see <https://github.com/tokio-rs/tls/issues/41>
38
39use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
51use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
52
/// Unwraps a `Poll::Ready` value, or returns `Poll::Pending` from the
/// enclosing function when the polled expression is not ready yet.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
61
62pub mod client;
63mod common;
64use common::{MidHandshake, TlsState};
65pub mod server;
66
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
///
/// Clones share the same configuration through the inner `Arc`.
#[derive(Clone)]
pub struct TlsConnector {
    /// Shared client configuration handed to every new `ClientConnection`.
    inner: Arc<ClientConfig>,
    /// Whether 0-RTT was requested through `TlsConnector::early_data`.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
74
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
///
/// Clones share the same configuration through the inner `Arc`.
#[derive(Clone)]
pub struct TlsAcceptor {
    /// Shared server configuration handed to every new `ServerConnection`.
    inner: Arc<ServerConfig>,
}
80
81impl From<Arc<ClientConfig>> for TlsConnector {
82    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
83        TlsConnector {
84            inner,
85            #[cfg(feature = "early-data")]
86            early_data: false,
87        }
88    }
89}
90
91impl From<Arc<ServerConfig>> for TlsAcceptor {
92    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
93        TlsAcceptor { inner }
94    }
95}
96
97impl TlsConnector {
98    /// Enable 0-RTT.
99    ///
100    /// If you want to use 0-RTT,
101    /// You must also set `ClientConfig.enable_early_data` to `true`.
102    #[cfg(feature = "early-data")]
103    pub fn early_data(mut self, flag: bool) -> TlsConnector {
104        self.early_data = flag;
105        self
106    }
107
108    #[inline]
109    pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
110    where
111        IO: AsyncRead + AsyncWrite + Unpin,
112    {
113        self.connect_with(domain, stream, |_| ())
114    }
115
116    pub fn connect_with<IO, F>(
117        &self,
118        domain: pki_types::ServerName<'static>,
119        stream: IO,
120        f: F,
121    ) -> Connect<IO>
122    where
123        IO: AsyncRead + AsyncWrite + Unpin,
124        F: FnOnce(&mut ClientConnection),
125    {
126        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
127            Ok(session) => session,
128            Err(error) => {
129                return Connect(MidHandshake::Error {
130                    io: stream,
131                    // TODO(eliza): should this really return an `io::Error`?
132                    // Probably not...
133                    error: io::Error::new(io::ErrorKind::Other, error),
134                });
135            }
136        };
137        f(&mut session);
138
139        Connect(MidHandshake::Handshaking(client::TlsStream {
140            io: stream,
141
142            #[cfg(not(feature = "early-data"))]
143            state: TlsState::Stream,
144
145            #[cfg(feature = "early-data")]
146            state: if self.early_data && session.early_data().is_some() {
147                TlsState::EarlyData(0, Vec::new())
148            } else {
149                TlsState::Stream
150            },
151
152            #[cfg(feature = "early-data")]
153            early_waker: None,
154
155            session,
156        }))
157    }
158}
159
160impl TlsAcceptor {
161    #[inline]
162    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
163    where
164        IO: AsyncRead + AsyncWrite + Unpin,
165    {
166        self.accept_with(stream, |_| ())
167    }
168
169    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
170    where
171        IO: AsyncRead + AsyncWrite + Unpin,
172        F: FnOnce(&mut ServerConnection),
173    {
174        let mut session = match ServerConnection::new(self.inner.clone()) {
175            Ok(session) => session,
176            Err(error) => {
177                return Accept(MidHandshake::Error {
178                    io: stream,
179                    // TODO(eliza): should this really return an `io::Error`?
180                    // Probably not...
181                    error: io::Error::new(io::ErrorKind::Other, error),
182                });
183            }
184        };
185        f(&mut session);
186
187        Accept(MidHandshake::Handshaking(server::TlsStream {
188            session,
189            io: stream,
190            state: TlsState::Stream,
191        }))
192    }
193}
194
/// A future that feeds data read from `io` into a `rustls::server::Acceptor`
/// and resolves to a [`StartHandshake`] once the acceptor reports an accepted
/// connection, so that the `ServerConfig` can be supplied afterwards
/// (see `StartHandshake::into_stream`).
pub struct LazyConfigAcceptor<IO> {
    /// The rustls acceptor that receives the bytes read from `io`.
    acceptor: rustls::server::Acceptor,
    /// The underlying stream; `None` once it has been moved into a
    /// `StartHandshake` or removed with `take_io`.
    io: Option<IO>,
}
199
200impl<IO> LazyConfigAcceptor<IO>
201where
202    IO: AsyncRead + AsyncWrite + Unpin,
203{
204    #[inline]
205    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
206        Self {
207            acceptor,
208            io: Some(io),
209        }
210    }
211
212    /// Takes back the client connection. Will return `None` if called more than once or if the
213    /// connection has been accepted.
214    ///
215    /// # Example
216    ///
217    /// ```no_run
218    /// # fn choose_server_config(
219    /// #     _: rustls::server::ClientHello,
220    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
221    /// #     unimplemented!();
222    /// # }
223    /// # #[allow(unused_variables)]
224    /// # async fn listen() {
225    /// use tokio::io::AsyncWriteExt;
226    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
227    /// let (stream, _) = listener.accept().await.unwrap();
228    ///
229    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
230    /// tokio::pin!(acceptor);
231    ///
232    /// match acceptor.as_mut().await {
233    ///     Ok(start) => {
234    ///         let clientHello = start.client_hello();
235    ///         let config = choose_server_config(clientHello);
236    ///         let stream = start.into_stream(config).await.unwrap();
237    ///         // Proceed with handling the ServerConnection...
238    ///     }
239    ///     Err(err) => {
240    ///         if let Some(mut stream) = acceptor.take_io() {
241    ///             stream
242    ///                 .write_all(
243    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
244    ///                         .as_bytes()
245    ///                 )
246    ///                 .await
247    ///                 .unwrap();
248    ///         }
249    ///     }
250    /// }
251    /// # }
252    /// ```
253    pub fn take_io(&mut self) -> Option<IO> {
254        self.io.take()
255    }
256}
257
258impl<IO> Future for LazyConfigAcceptor<IO>
259where
260    IO: AsyncRead + AsyncWrite + Unpin,
261{
262    type Output = Result<StartHandshake<IO>, io::Error>;
263
264    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        let this = self.get_mut();
266        loop {
267            let io = match this.io.as_mut() {
268                Some(io) => io,
269                None => {
270                    return Poll::Ready(Err(io::Error::new(
271                        io::ErrorKind::Other,
272                        "acceptor cannot be polled after acceptance",
273                    )))
274                }
275            };
276
277            let mut reader = common::SyncReadAdapter { io, cx };
278            match this.acceptor.read_tls(&mut reader) {
279                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
280                Ok(_) => {}
281                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
282                Err(e) => return Err(e).into(),
283            }
284
285            match this.acceptor.accept() {
286                Ok(Some(accepted)) => {
287                    let io = this.io.take().unwrap();
288                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
289                }
290                Ok(None) => continue,
291                Err(err) => {
292                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
293                }
294            }
295        }
296    }
297}
298
/// The output of a successfully resolved [`LazyConfigAcceptor`]: an accepted
/// connection paired with its stream, waiting for a `ServerConfig` to be
/// supplied through `into_stream` or `into_stream_with`.
pub struct StartHandshake<IO> {
    /// The accepted connection produced by the rustls acceptor.
    accepted: rustls::server::Accepted,
    /// The underlying stream the connection was read from.
    io: IO,
}
303
304impl<IO> StartHandshake<IO>
305where
306    IO: AsyncRead + AsyncWrite + Unpin,
307{
308    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
309        self.accepted.client_hello()
310    }
311
312    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
313        self.into_stream_with(config, |_| ())
314    }
315
316    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
317    where
318        F: FnOnce(&mut ServerConnection),
319    {
320        let mut conn = match self.accepted.into_connection(config) {
321            Ok(conn) => conn,
322            Err(error) => {
323                return Accept(MidHandshake::Error {
324                    io: self.io,
325                    // TODO(eliza): should this really return an `io::Error`?
326                    // Probably not...
327                    error: io::Error::new(io::ErrorKind::Other, error),
328                });
329            }
330        };
331        f(&mut conn);
332
333        Accept(MidHandshake::Handshaking(server::TlsStream {
334            session: conn,
335            io: self.io,
336            state: TlsState::Stream,
337        }))
338    }
339}
340
/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
///
/// On failure only the `io::Error` is returned; convert with `into_fallible`
/// to also get the `IO` back.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
344
/// Future returned from `TlsAcceptor::accept` which will resolve
/// once the accept handshake has finished.
///
/// On failure only the `io::Error` is returned; convert with `into_fallible`
/// to also get the `IO` back.
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
348
/// Like [Connect], but returns `IO` on failure.
///
/// The error output is the pair `(io::Error, IO)`.
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
351
/// Like [Accept], but returns `IO` on failure.
///
/// The error output is the pair `(io::Error, IO)`.
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
354
355impl<IO> Connect<IO> {
356    #[inline]
357    pub fn into_fallible(self) -> FallibleConnect<IO> {
358        FallibleConnect(self.0)
359    }
360
361    pub fn get_ref(&self) -> Option<&IO> {
362        match &self.0 {
363            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
364            MidHandshake::Error { io, .. } => Some(io),
365            MidHandshake::End => None,
366        }
367    }
368
369    pub fn get_mut(&mut self) -> Option<&mut IO> {
370        match &mut self.0 {
371            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
372            MidHandshake::Error { io, .. } => Some(io),
373            MidHandshake::End => None,
374        }
375    }
376}
377
378impl<IO> Accept<IO> {
379    #[inline]
380    pub fn into_fallible(self) -> FallibleAccept<IO> {
381        FallibleAccept(self.0)
382    }
383
384    pub fn get_ref(&self) -> Option<&IO> {
385        match &self.0 {
386            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
387            MidHandshake::Error { io, .. } => Some(io),
388            MidHandshake::End => None,
389        }
390    }
391
392    pub fn get_mut(&mut self) -> Option<&mut IO> {
393        match &mut self.0 {
394            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
395            MidHandshake::Error { io, .. } => Some(io),
396            MidHandshake::End => None,
397        }
398    }
399}
400
401impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
402    type Output = io::Result<client::TlsStream<IO>>;
403
404    #[inline]
405    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
406        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
407    }
408}
409
410impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
411    type Output = io::Result<server::TlsStream<IO>>;
412
413    #[inline]
414    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
415        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
416    }
417}
418
419impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
420    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
421
422    #[inline]
423    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
424        Pin::new(&mut self.0).poll(cx)
425    }
426}
427
428impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
429    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
430
431    #[inline]
432    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
433        Pin::new(&mut self.0).poll(cx)
434    }
435}
436
/// Unified TLS stream type
///
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
/// a single type to keep both client- and server-initiated TLS-encrypted connections.
#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
#[derive(Debug)]
pub enum TlsStream<T> {
    /// A stream wrapping a `client::TlsStream`.
    Client(client::TlsStream<T>),
    /// A stream wrapping a `server::TlsStream`.
    Server(server::TlsStream<T>),
}
447
448impl<T> TlsStream<T> {
449    pub fn get_ref(&self) -> (&T, &CommonState) {
450        use TlsStream::*;
451        match self {
452            Client(io) => {
453                let (io, session) = io.get_ref();
454                (io, session)
455            }
456            Server(io) => {
457                let (io, session) = io.get_ref();
458                (io, session)
459            }
460        }
461    }
462
463    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
464        use TlsStream::*;
465        match self {
466            Client(io) => {
467                let (io, session) = io.get_mut();
468                (io, &mut *session)
469            }
470            Server(io) => {
471                let (io, session) = io.get_mut();
472                (io, &mut *session)
473            }
474        }
475    }
476}
477
478impl<T> From<client::TlsStream<T>> for TlsStream<T> {
479    fn from(s: client::TlsStream<T>) -> Self {
480        Self::Client(s)
481    }
482}
483
484impl<T> From<server::TlsStream<T>> for TlsStream<T> {
485    fn from(s: server::TlsStream<T>) -> Self {
486        Self::Server(s)
487    }
488}
489
490#[cfg(unix)]
491impl<S> AsRawFd for TlsStream<S>
492where
493    S: AsRawFd,
494{
495    fn as_raw_fd(&self) -> RawFd {
496        self.get_ref().0.as_raw_fd()
497    }
498}
499
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying IO.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
509
510impl<T> AsyncRead for TlsStream<T>
511where
512    T: AsyncRead + AsyncWrite + Unpin,
513{
514    #[inline]
515    fn poll_read(
516        self: Pin<&mut Self>,
517        cx: &mut Context<'_>,
518        buf: &mut ReadBuf<'_>,
519    ) -> Poll<io::Result<()>> {
520        match self.get_mut() {
521            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
522            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
523        }
524    }
525}
526
527impl<T> AsyncWrite for TlsStream<T>
528where
529    T: AsyncRead + AsyncWrite + Unpin,
530{
531    #[inline]
532    fn poll_write(
533        self: Pin<&mut Self>,
534        cx: &mut Context<'_>,
535        buf: &[u8],
536    ) -> Poll<io::Result<usize>> {
537        match self.get_mut() {
538            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
539            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
540        }
541    }
542
543    #[inline]
544    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
545        match self.get_mut() {
546            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
547            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
548        }
549    }
550
551    #[inline]
552    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
553        match self.get_mut() {
554            TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
555            TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
556        }
557    }
558}