tokio_rustls/
client.rs

1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use rustls::ClientConnection;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::common::{IoSession, Stream, TlsState};
13
/// A client-side wrapper around an underlying raw stream which implements the TLS
/// protocol (via a rustls `ClientConnection`).
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying transport that TLS records are read from and written to.
    pub(crate) io: IO,
    /// The rustls client connection holding the handshake and record-layer state.
    pub(crate) session: ClientConnection,
    /// Connection phase: early data (feature `early-data`), open stream, or which
    /// halves (read/write) have been shut down.
    pub(crate) state: TlsState,

    /// Waker registered by `poll_read` while in the early-data state; it is taken and
    /// woken by the write side once the stream switches to `TlsState::Stream`.
    #[cfg(feature = "early-data")]
    pub(crate) early_waker: Option<std::task::Waker>,
}
25
26impl<IO> TlsStream<IO> {
27    #[inline]
28    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
29        (&self.io, &self.session)
30    }
31
32    #[inline]
33    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
34        (&mut self.io, &mut self.session)
35    }
36
37    #[inline]
38    pub fn into_inner(self) -> (IO, ClientConnection) {
39        (self.io, self.session)
40    }
41}
42
43#[cfg(unix)]
44impl<S> AsRawFd for TlsStream<S>
45where
46    S: AsRawFd,
47{
48    fn as_raw_fd(&self) -> RawFd {
49        self.get_ref().0.as_raw_fd()
50    }
51}
52
#[cfg(windows)]
impl<S: AsRawSocket> AsRawSocket for TlsStream<S> {
    /// Returns the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
62
63impl<IO> IoSession for TlsStream<IO> {
64    type Io = IO;
65    type Session = ClientConnection;
66
67    #[inline]
68    fn skip_handshake(&self) -> bool {
69        self.state.is_early_data()
70    }
71
72    #[inline]
73    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
74        (&mut self.state, &mut self.io, &mut self.session)
75    }
76
77    #[inline]
78    fn into_io(self) -> Self::Io {
79        self.io
80    }
81}
82
83impl<IO> AsyncRead for TlsStream<IO>
84where
85    IO: AsyncRead + AsyncWrite + Unpin,
86{
87    fn poll_read(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90        buf: &mut ReadBuf<'_>,
91    ) -> Poll<io::Result<()>> {
92        match self.state {
93            #[cfg(feature = "early-data")]
94            TlsState::EarlyData(..) => {
95                let this = self.get_mut();
96
97                // In the EarlyData state, we have not really established a Tls connection.
98                // Before writing data through `AsyncWrite` and completing the tls handshake,
99                // we ignore read readiness and return to pending.
100                //
101                // In order to avoid event loss,
102                // we need to register a waker and wake it up after tls is connected.
103                if this
104                    .early_waker
105                    .as_ref()
106                    .filter(|waker| cx.waker().will_wake(waker))
107                    .is_none()
108                {
109                    this.early_waker = Some(cx.waker().clone());
110                }
111
112                Poll::Pending
113            }
114            TlsState::Stream | TlsState::WriteShutdown => {
115                let this = self.get_mut();
116                let mut stream =
117                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
118                let prev = buf.remaining();
119
120                match stream.as_mut_pin().poll_read(cx, buf) {
121                    Poll::Ready(Ok(())) => {
122                        if prev == buf.remaining() || stream.eof {
123                            this.state.shutdown_read();
124                        }
125
126                        Poll::Ready(Ok(()))
127                    }
128                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
129                        this.state.shutdown_read();
130                        Poll::Ready(Err(err))
131                    }
132                    output => output,
133                }
134            }
135            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
136        }
137    }
138}
139
140impl<IO> AsyncWrite for TlsStream<IO>
141where
142    IO: AsyncRead + AsyncWrite + Unpin,
143{
144    /// Note: that it does not guarantee the final data to be sent.
145    /// To be cautious, you must manually call `flush`.
146    fn poll_write(
147        self: Pin<&mut Self>,
148        cx: &mut Context<'_>,
149        buf: &[u8],
150    ) -> Poll<io::Result<usize>> {
151        let this = self.get_mut();
152        let mut stream =
153            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
154
155        #[allow(clippy::match_single_binding)]
156        match this.state {
157            #[cfg(feature = "early-data")]
158            TlsState::EarlyData(ref mut pos, ref mut data) => {
159                use std::io::Write;
160
161                // write early data
162                if let Some(mut early_data) = stream.session.early_data() {
163                    let len = match early_data.write(buf) {
164                        Ok(n) => n,
165                        Err(err) => return Poll::Ready(Err(err)),
166                    };
167                    if len != 0 {
168                        data.extend_from_slice(&buf[..len]);
169                        return Poll::Ready(Ok(len));
170                    }
171                }
172
173                // complete handshake
174                while stream.session.is_handshaking() {
175                    ready!(stream.handshake(cx))?;
176                }
177
178                // write early data (fallback)
179                if !stream.session.is_early_data_accepted() {
180                    while *pos < data.len() {
181                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
182                        *pos += len;
183                    }
184                }
185
186                // end
187                this.state = TlsState::Stream;
188
189                if let Some(waker) = this.early_waker.take() {
190                    waker.wake();
191                }
192
193                stream.as_mut_pin().poll_write(cx, buf)
194            }
195            _ => stream.as_mut_pin().poll_write(cx, buf),
196        }
197    }
198
199    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200        let this = self.get_mut();
201        let mut stream =
202            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
203
204        #[cfg(feature = "early-data")]
205        {
206            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
207                // complete handshake
208                while stream.session.is_handshaking() {
209                    ready!(stream.handshake(cx))?;
210                }
211
212                // write early data (fallback)
213                if !stream.session.is_early_data_accepted() {
214                    while *pos < data.len() {
215                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
216                        *pos += len;
217                    }
218                }
219
220                this.state = TlsState::Stream;
221
222                if let Some(waker) = this.early_waker.take() {
223                    waker.wake();
224                }
225            }
226        }
227
228        stream.as_mut_pin().poll_flush(cx)
229    }
230
231    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232        #[cfg(feature = "early-data")]
233        {
234            // complete handshake
235            if matches!(self.state, TlsState::EarlyData(..)) {
236                ready!(self.as_mut().poll_flush(cx))?;
237            }
238        }
239
240        if self.state.writeable() {
241            self.session.send_close_notify();
242            self.state.shutdown_write();
243        }
244
245        let this = self.get_mut();
246        let mut stream =
247            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
248        stream.as_mut_pin().poll_shutdown(cx)
249    }
250}