1use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
51use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
52
/// Extracts the value from a `Poll::Ready`, or returns `Poll::Pending` from the
/// enclosing function when the polled operation is not ready yet.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
61
62pub mod client;
63mod common;
64use common::{MidHandshake, TlsState};
65pub mod server;
66
/// Client-side entry point: wraps a shared rustls [`ClientConfig`] and starts
/// client TLS handshakes over caller-supplied IO objects.
///
/// Cloning is cheap; it clones the `Arc` around the configuration.
#[derive(Clone)]
pub struct TlsConnector {
    // Configuration used to build a new `ClientConnection` for every `connect_with` call.
    inner: Arc<ClientConfig>,
    // Whether to start new streams in the early-data state when the session allows it.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
74
/// Server-side entry point: wraps a shared rustls [`ServerConfig`] and starts
/// server TLS handshakes over caller-supplied IO objects.
///
/// Cloning is cheap; it clones the `Arc` around the configuration.
#[derive(Clone)]
pub struct TlsAcceptor {
    // Configuration used to build a new `ServerConnection` for every `accept_with` call.
    inner: Arc<ServerConfig>,
}
80
81impl From<Arc<ClientConfig>> for TlsConnector {
82 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
83 TlsConnector {
84 inner,
85 #[cfg(feature = "early-data")]
86 early_data: false,
87 }
88 }
89}
90
91impl From<Arc<ServerConfig>> for TlsAcceptor {
92 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
93 TlsAcceptor { inner }
94 }
95}
96
97impl TlsConnector {
98 #[cfg(feature = "early-data")]
103 pub fn early_data(mut self, flag: bool) -> TlsConnector {
104 self.early_data = flag;
105 self
106 }
107
108 #[inline]
109 pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
110 where
111 IO: AsyncRead + AsyncWrite + Unpin,
112 {
113 self.connect_with(domain, stream, |_| ())
114 }
115
116 pub fn connect_with<IO, F>(
117 &self,
118 domain: pki_types::ServerName<'static>,
119 stream: IO,
120 f: F,
121 ) -> Connect<IO>
122 where
123 IO: AsyncRead + AsyncWrite + Unpin,
124 F: FnOnce(&mut ClientConnection),
125 {
126 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
127 Ok(session) => session,
128 Err(error) => {
129 return Connect(MidHandshake::Error {
130 io: stream,
131 error: io::Error::new(io::ErrorKind::Other, error),
134 });
135 }
136 };
137 f(&mut session);
138
139 Connect(MidHandshake::Handshaking(client::TlsStream {
140 io: stream,
141
142 #[cfg(not(feature = "early-data"))]
143 state: TlsState::Stream,
144
145 #[cfg(feature = "early-data")]
146 state: if self.early_data && session.early_data().is_some() {
147 TlsState::EarlyData(0, Vec::new())
148 } else {
149 TlsState::Stream
150 },
151
152 #[cfg(feature = "early-data")]
153 early_waker: None,
154
155 session,
156 }))
157 }
158}
159
160impl TlsAcceptor {
161 #[inline]
162 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
163 where
164 IO: AsyncRead + AsyncWrite + Unpin,
165 {
166 self.accept_with(stream, |_| ())
167 }
168
169 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
170 where
171 IO: AsyncRead + AsyncWrite + Unpin,
172 F: FnOnce(&mut ServerConnection),
173 {
174 let mut session = match ServerConnection::new(self.inner.clone()) {
175 Ok(session) => session,
176 Err(error) => {
177 return Accept(MidHandshake::Error {
178 io: stream,
179 error: io::Error::new(io::ErrorKind::Other, error),
182 });
183 }
184 };
185 f(&mut session);
186
187 Accept(MidHandshake::Handshaking(server::TlsStream {
188 session,
189 io: stream,
190 state: TlsState::Stream,
191 }))
192 }
193}
194
/// Future that reads TLS data from `io` into a rustls [`rustls::server::Acceptor`]
/// until a ClientHello is accepted, yielding a [`StartHandshake`] so the caller
/// can inspect the hello before choosing a [`ServerConfig`].
pub struct LazyConfigAcceptor<IO> {
    // Receives the bytes read from `io`; `accept()` is queried after each read.
    acceptor: rustls::server::Acceptor,
    // `Some` until moved into the `StartHandshake` on success (or taken via `take_io`).
    io: Option<IO>,
}
199
200impl<IO> LazyConfigAcceptor<IO>
201where
202 IO: AsyncRead + AsyncWrite + Unpin,
203{
204 #[inline]
205 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
206 Self {
207 acceptor,
208 io: Some(io),
209 }
210 }
211
212 pub fn take_io(&mut self) -> Option<IO> {
254 self.io.take()
255 }
256}
257
258impl<IO> Future for LazyConfigAcceptor<IO>
259where
260 IO: AsyncRead + AsyncWrite + Unpin,
261{
262 type Output = Result<StartHandshake<IO>, io::Error>;
263
264 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265 let this = self.get_mut();
266 loop {
267 let io = match this.io.as_mut() {
268 Some(io) => io,
269 None => {
270 return Poll::Ready(Err(io::Error::new(
271 io::ErrorKind::Other,
272 "acceptor cannot be polled after acceptance",
273 )))
274 }
275 };
276
277 let mut reader = common::SyncReadAdapter { io, cx };
278 match this.acceptor.read_tls(&mut reader) {
279 Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
280 Ok(_) => {}
281 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
282 Err(e) => return Err(e).into(),
283 }
284
285 match this.acceptor.accept() {
286 Ok(Some(accepted)) => {
287 let io = this.io.take().unwrap();
288 return Poll::Ready(Ok(StartHandshake { accepted, io }));
289 }
290 Ok(None) => continue,
291 Err(err) => {
292 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
293 }
294 }
295 }
296 }
297}
298
/// Output of [`LazyConfigAcceptor`]: an accepted ClientHello together with the IO
/// object. Inspect the hello via `client_hello`, then call `into_stream` or
/// `into_stream_with` with the chosen [`ServerConfig`] to continue the handshake.
pub struct StartHandshake<IO> {
    // The rustls-accepted ClientHello, later turned into a `ServerConnection`.
    accepted: rustls::server::Accepted,
    // The transport the ClientHello was read from.
    io: IO,
}
303
304impl<IO> StartHandshake<IO>
305where
306 IO: AsyncRead + AsyncWrite + Unpin,
307{
308 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
309 self.accepted.client_hello()
310 }
311
312 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
313 self.into_stream_with(config, |_| ())
314 }
315
316 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
317 where
318 F: FnOnce(&mut ServerConnection),
319 {
320 let mut conn = match self.accepted.into_connection(config) {
321 Ok(conn) => conn,
322 Err(error) => {
323 return Accept(MidHandshake::Error {
324 io: self.io,
325 error: io::Error::new(io::ErrorKind::Other, error),
328 });
329 }
330 };
331 f(&mut conn);
332
333 Accept(MidHandshake::Handshaking(server::TlsStream {
334 session: conn,
335 io: self.io,
336 state: TlsState::Stream,
337 }))
338 }
339}
340
/// Future for a client handshake, returned by [`TlsConnector::connect`] and
/// [`TlsConnector::connect_with`].
///
/// Resolves to `io::Result<client::TlsStream<IO>>`; on failure the IO object is
/// discarded. Use [`Connect::into_fallible`] to get it back alongside the error.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Future for a server handshake, returned by [`TlsAcceptor::accept`],
/// [`TlsAcceptor::accept_with`] and [`StartHandshake::into_stream`].
///
/// Resolves to `io::Result<server::TlsStream<IO>>`; on failure the IO object is
/// discarded. Use [`Accept::into_fallible`] to get it back alongside the error.
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);

/// Like [`Connect`], but resolves to `Err((io::Error, IO))` on failure so the IO object is returned.
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Like [`Accept`], but resolves to `Err((io::Error, IO))` on failure so the IO object is returned.
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
354
355impl<IO> Connect<IO> {
356 #[inline]
357 pub fn into_fallible(self) -> FallibleConnect<IO> {
358 FallibleConnect(self.0)
359 }
360
361 pub fn get_ref(&self) -> Option<&IO> {
362 match &self.0 {
363 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
364 MidHandshake::Error { io, .. } => Some(io),
365 MidHandshake::End => None,
366 }
367 }
368
369 pub fn get_mut(&mut self) -> Option<&mut IO> {
370 match &mut self.0 {
371 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
372 MidHandshake::Error { io, .. } => Some(io),
373 MidHandshake::End => None,
374 }
375 }
376}
377
378impl<IO> Accept<IO> {
379 #[inline]
380 pub fn into_fallible(self) -> FallibleAccept<IO> {
381 FallibleAccept(self.0)
382 }
383
384 pub fn get_ref(&self) -> Option<&IO> {
385 match &self.0 {
386 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
387 MidHandshake::Error { io, .. } => Some(io),
388 MidHandshake::End => None,
389 }
390 }
391
392 pub fn get_mut(&mut self) -> Option<&mut IO> {
393 match &mut self.0 {
394 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
395 MidHandshake::Error { io, .. } => Some(io),
396 MidHandshake::End => None,
397 }
398 }
399}
400
401impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
402 type Output = io::Result<client::TlsStream<IO>>;
403
404 #[inline]
405 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
406 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
407 }
408}
409
410impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
411 type Output = io::Result<server::TlsStream<IO>>;
412
413 #[inline]
414 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
415 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
416 }
417}
418
419impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
420 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
421
422 #[inline]
423 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
424 Pin::new(&mut self.0).poll(cx)
425 }
426}
427
428impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
429 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
430
431 #[inline]
432 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
433 Pin::new(&mut self.0).poll(cx)
434 }
435}
436
/// Unified TLS stream type covering both sides of a connection.
///
/// Wraps either a `client::TlsStream` or a `server::TlsStream`, so client- and
/// server-initiated connections can be stored and used through one type.
// NOTE(review): the lint allow presumably accepts the size difference between
// the two variants rather than boxing one (an extra allocation per stream); confirm.
#[allow(clippy::large_enum_variant)] #[derive(Debug)]
pub enum TlsStream<T> {
    Client(client::TlsStream<T>),
    Server(server::TlsStream<T>),
}
447
448impl<T> TlsStream<T> {
449 pub fn get_ref(&self) -> (&T, &CommonState) {
450 use TlsStream::*;
451 match self {
452 Client(io) => {
453 let (io, session) = io.get_ref();
454 (io, session)
455 }
456 Server(io) => {
457 let (io, session) = io.get_ref();
458 (io, session)
459 }
460 }
461 }
462
463 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
464 use TlsStream::*;
465 match self {
466 Client(io) => {
467 let (io, session) = io.get_mut();
468 (io, &mut *session)
469 }
470 Server(io) => {
471 let (io, session) = io.get_mut();
472 (io, &mut *session)
473 }
474 }
475 }
476}
477
478impl<T> From<client::TlsStream<T>> for TlsStream<T> {
479 fn from(s: client::TlsStream<T>) -> Self {
480 Self::Client(s)
481 }
482}
483
484impl<T> From<server::TlsStream<T>> for TlsStream<T> {
485 fn from(s: server::TlsStream<T>) -> Self {
486 Self::Server(s)
487 }
488}
489
490#[cfg(unix)]
491impl<S> AsRawFd for TlsStream<S>
492where
493 S: AsRawFd,
494{
495 fn as_raw_fd(&self) -> RawFd {
496 self.get_ref().0.as_raw_fd()
497 }
498}
499
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _state) = self.get_ref();
        io.as_raw_socket()
    }
}
509
510impl<T> AsyncRead for TlsStream<T>
511where
512 T: AsyncRead + AsyncWrite + Unpin,
513{
514 #[inline]
515 fn poll_read(
516 self: Pin<&mut Self>,
517 cx: &mut Context<'_>,
518 buf: &mut ReadBuf<'_>,
519 ) -> Poll<io::Result<()>> {
520 match self.get_mut() {
521 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
522 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
523 }
524 }
525}
526
527impl<T> AsyncWrite for TlsStream<T>
528where
529 T: AsyncRead + AsyncWrite + Unpin,
530{
531 #[inline]
532 fn poll_write(
533 self: Pin<&mut Self>,
534 cx: &mut Context<'_>,
535 buf: &[u8],
536 ) -> Poll<io::Result<usize>> {
537 match self.get_mut() {
538 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
539 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
540 }
541 }
542
543 #[inline]
544 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
545 match self.get_mut() {
546 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
547 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
548 }
549 }
550
551 #[inline]
552 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
553 match self.get_mut() {
554 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
555 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
556 }
557 }
558}