1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use rustls::ClientConnection;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::common::{IoSession, Stream, TlsState};
13
#[derive(Debug)]
/// Client-side TLS stream.
///
/// Bundles an underlying transport `IO` with a rustls [`ClientConnection`] and a
/// [`TlsState`] value that the `AsyncRead`/`AsyncWrite` implementations consult to
/// decide how to service each poll.
pub struct TlsStream<IO> {
    /// The underlying transport the TLS session reads from and writes to.
    pub(crate) io: IO,
    /// The rustls client session driving the TLS protocol.
    pub(crate) session: ClientConnection,
    /// Current state of the stream (early data, streaming, or one of the shutdown states).
    pub(crate) state: TlsState,

    /// Waker stored by `poll_read` while the stream is still in the early-data state;
    /// it is taken and woken by `poll_write`/`poll_flush` once they leave that state.
    #[cfg(feature = "early-data")]
    pub(crate) early_waker: Option<std::task::Waker>,
}
25
26impl<IO> TlsStream<IO> {
27 #[inline]
28 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
29 (&self.io, &self.session)
30 }
31
32 #[inline]
33 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
34 (&mut self.io, &mut self.session)
35 }
36
37 #[inline]
38 pub fn into_inner(self) -> (IO, ClientConnection) {
39 (self.io, self.session)
40 }
41}
42
43#[cfg(unix)]
44impl<S> AsRawFd for TlsStream<S>
45where
46 S: AsRawFd,
47{
48 fn as_raw_fd(&self) -> RawFd {
49 self.get_ref().0.as_raw_fd()
50 }
51}
52
#[cfg(windows)]
impl<S: AsRawSocket> AsRawSocket for TlsStream<S> {
    /// Exposes the raw socket handle of the underlying transport.
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _session) = self.get_ref();
        transport.as_raw_socket()
    }
}
62
63impl<IO> IoSession for TlsStream<IO> {
64 type Io = IO;
65 type Session = ClientConnection;
66
67 #[inline]
68 fn skip_handshake(&self) -> bool {
69 self.state.is_early_data()
70 }
71
72 #[inline]
73 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
74 (&mut self.state, &mut self.io, &mut self.session)
75 }
76
77 #[inline]
78 fn into_io(self) -> Self::Io {
79 self.io
80 }
81}
82
83impl<IO> AsyncRead for TlsStream<IO>
84where
85 IO: AsyncRead + AsyncWrite + Unpin,
86{
87 fn poll_read(
88 self: Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 buf: &mut ReadBuf<'_>,
91 ) -> Poll<io::Result<()>> {
92 match self.state {
93 #[cfg(feature = "early-data")]
94 TlsState::EarlyData(..) => {
95 let this = self.get_mut();
96
97 if this
104 .early_waker
105 .as_ref()
106 .filter(|waker| cx.waker().will_wake(waker))
107 .is_none()
108 {
109 this.early_waker = Some(cx.waker().clone());
110 }
111
112 Poll::Pending
113 }
114 TlsState::Stream | TlsState::WriteShutdown => {
115 let this = self.get_mut();
116 let mut stream =
117 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
118 let prev = buf.remaining();
119
120 match stream.as_mut_pin().poll_read(cx, buf) {
121 Poll::Ready(Ok(())) => {
122 if prev == buf.remaining() || stream.eof {
123 this.state.shutdown_read();
124 }
125
126 Poll::Ready(Ok(()))
127 }
128 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
129 this.state.shutdown_read();
130 Poll::Ready(Err(err))
131 }
132 output => output,
133 }
134 }
135 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
136 }
137 }
138}
139
impl<IO> AsyncWrite for TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` through the TLS session.
    ///
    /// While the stream is in the early-data state (only with the `early-data`
    /// feature) the bytes are first offered to the session's early-data writer.
    /// Accepted bytes are also recorded so they can be replayed later. When no
    /// early data could be written, the handshake is driven to completion, the
    /// recorded bytes are replayed if early data was not accepted, the state
    /// becomes `TlsState::Stream`, a reader parked by `poll_read` is woken, and
    /// `buf` is written normally.
    ///
    /// NOTE(review): a successful write presumably only hands data to the TLS
    /// layer; callers likely still need `poll_flush` to push it to the
    /// transport — confirm against `Stream`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

        // Without the `early-data` feature the `EarlyData` arm below is compiled
        // out and the match is left with only the wildcard arm, which is what
        // this clippy lint would otherwise flag.
        #[allow(clippy::match_single_binding)]
        match this.state {
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(ref mut pos, ref mut data) => {
                use std::io::Write;

                // Try to send the bytes as early data.
                if let Some(mut early_data) = stream.session.early_data() {
                    let len = match early_data.write(buf) {
                        Ok(n) => n,
                        Err(err) => return Poll::Ready(Err(err)),
                    };
                    if len != 0 {
                        // Keep a copy of what was accepted so it can be replayed
                        // below if early data turns out not to be accepted.
                        data.extend_from_slice(&buf[..len]);
                        return Poll::Ready(Ok(len));
                    }
                }

                // Nothing could be written as early data: complete the handshake.
                // `ready!` (defined outside this view) presumably returns
                // `Poll::Pending` from this function when the handshake is not
                // ready yet.
                while stream.session.is_handshaking() {
                    ready!(stream.handshake(cx))?;
                }

                // Fallback: early data was not accepted, so replay the recorded
                // bytes as regular data. `pos` tracks how much of `data` has
                // already been re-sent across polls.
                if !stream.session.is_early_data_accepted() {
                    while *pos < data.len() {
                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                        *pos += len;
                    }
                }

                // Early-data phase is over; switch to normal streaming.
                this.state = TlsState::Stream;

                // Wake a reader whose waker was stored by `poll_read` while the
                // stream was still in the early-data state.
                if let Some(waker) = this.early_waker.take() {
                    waker.wake();
                }

                stream.as_mut_pin().poll_write(cx, buf)
            }
            _ => stream.as_mut_pin().poll_write(cx, buf),
        }
    }

    /// Flushes the TLS stream.
    ///
    /// If the stream is still in the early-data state, this first completes the
    /// handshake, replays recorded early data when it was not accepted, moves to
    /// `TlsState::Stream` and wakes a parked reader. It then delegates to the
    /// inner `Stream`'s `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

        #[cfg(feature = "early-data")]
        {
            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
                // Complete the handshake.
                while stream.session.is_handshaking() {
                    ready!(stream.handshake(cx))?;
                }

                // Fallback: replay recorded early data as regular data when it
                // was not accepted; `pos` is the resume offset into `data`.
                if !stream.session.is_early_data_accepted() {
                    while *pos < data.len() {
                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                        *pos += len;
                    }
                }

                // Early-data phase is over; switch to normal streaming.
                this.state = TlsState::Stream;

                // Wake a reader parked by `poll_read` during the early-data state.
                if let Some(waker) = this.early_waker.take() {
                    waker.wake();
                }
            }
        }

        stream.as_mut_pin().poll_flush(cx)
    }

    /// Shuts down the write side of the TLS stream.
    ///
    /// A stream still in the early-data state is flushed first, which completes
    /// the handshake. While the state reports itself writeable, a close-notify is
    /// issued on the session and the state is marked write-shutdown. Finally the
    /// call is delegated to the inner `Stream`'s `poll_shutdown`.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        #[cfg(feature = "early-data")]
        {
            // Leave the early-data state (handshake + replay) before closing.
            if matches!(self.state, TlsState::EarlyData(..)) {
                ready!(self.as_mut().poll_flush(cx))?;
            }
        }

        // NOTE(review): `shutdown_write` presumably makes `writeable()` return
        // false afterwards, so close-notify is only issued once across repeated
        // polls — confirm in `TlsState`.
        if self.state.writeable() {
            self.session.send_close_notify();
            self.state.shutdown_write();
        }

        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
        stream.as_mut_pin().poll_shutdown(cx)
    }
}