1use std::io::{Read, Write};
3
4use crate::{
5 client::{client_with_config, uri_mode, IntoClientRequest},
6 error::UrlError,
7 handshake::client::Response,
8 protocol::WebSocketConfig,
9 stream::MaybeTlsStream,
10 ClientHandshake, Error, HandshakeError, Result, WebSocket,
11};
12
/// Selects how a client stream is wrapped before the WebSocket handshake:
/// left as plain text, or upgraded to TLS using one of the compiled-in backends.
///
/// The TLS variants only exist when the corresponding Cargo feature is enabled.
#[non_exhaustive]
// NOTE(review): `Debug` is deliberately not derived here; presumably because the
// wrapped backend types are not guaranteed to implement it — confirm.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// Do not use TLS; the stream is passed through unencrypted.
    Plain,
    /// Use the supplied `native-tls` connector for the TLS handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Use the supplied shared `rustls` client configuration for the TLS session.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
28
29mod encryption {
30 #[cfg(feature = "native-tls")]
31 pub mod native_tls {
32 use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33
34 use std::io::{Read, Write};
35
36 use crate::{
37 error::TlsError,
38 stream::{MaybeTlsStream, Mode},
39 Error, Result,
40 };
41
42 pub fn wrap_stream<S>(
43 socket: S,
44 domain: &str,
45 mode: Mode,
46 tls_connector: Option<TlsConnector>,
47 ) -> Result<MaybeTlsStream<S>>
48 where
49 S: Read + Write,
50 {
51 match mode {
52 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53 Mode::Tls => {
54 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55 let connector = try_connector.map_err(TlsError::Native)?;
56 let connected = connector.connect(domain, socket);
57 match connected {
58 Err(e) => match e {
59 TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60 TlsHandshakeError::WouldBlock(_) => {
61 panic!("Bug: TLS handshake not blocked")
62 }
63 },
64 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65 }
66 }
67 }
68 }
69 }
70
71 #[cfg(feature = "__rustls-tls")]
72 pub mod rustls {
73 use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
74 use rustls_pki_types::ServerName;
75
76 use std::{
77 convert::TryFrom,
78 io::{Read, Write},
79 sync::Arc,
80 };
81
82 use crate::{
83 error::TlsError,
84 stream::{MaybeTlsStream, Mode},
85 Result,
86 };
87
88 pub fn wrap_stream<S>(
89 socket: S,
90 domain: &str,
91 mode: Mode,
92 tls_connector: Option<Arc<ClientConfig>>,
93 ) -> Result<MaybeTlsStream<S>>
94 where
95 S: Read + Write,
96 {
97 match mode {
98 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
99 Mode::Tls => {
100 let config = match tls_connector {
101 Some(config) => config,
102 None => {
103 #[allow(unused_mut)]
104 let mut root_store = RootCertStore::empty();
105
106 #[cfg(feature = "rustls-tls-native-roots")]
107 {
108 let native_certs = rustls_native_certs::load_native_certs()?;
109 let total_number = native_certs.len();
110 let (number_added, number_ignored) =
111 root_store.add_parsable_certificates(native_certs);
112 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
113 }
114 #[cfg(feature = "rustls-tls-webpki-roots")]
115 {
116 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
117 }
118
119 Arc::new(
120 ClientConfig::builder()
121 .with_root_certificates(root_store)
122 .with_no_client_auth(),
123 )
124 }
125 };
126 let domain = ServerName::try_from(domain)
127 .map_err(|_| TlsError::InvalidDnsName)?
128 .to_owned();
129 let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
130 let stream = StreamOwned::new(client, socket);
131
132 Ok(MaybeTlsStream::Rustls(stream))
133 }
134 }
135 }
136 }
137
138 pub mod plain {
139 use std::io::{Read, Write};
140
141 use crate::{
142 error::UrlError,
143 stream::{MaybeTlsStream, Mode},
144 Error, Result,
145 };
146
147 pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
148 where
149 S: Read + Write,
150 {
151 match mode {
152 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
153 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
154 }
155 }
156 }
157}
158
/// Error type returned by the TLS-aware client helpers in this file: a
/// `HandshakeError` for a client handshake running over a `MaybeTlsStream<S>`.
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
160
161pub fn client_tls<R, S>(
164 request: R,
165 stream: S,
166) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
167where
168 R: IntoClientRequest,
169 S: Read + Write,
170{
171 client_tls_with_config(request, stream, None, None)
172}
173
174pub fn client_tls_with_config<R, S>(
180 request: R,
181 stream: S,
182 config: Option<WebSocketConfig>,
183 connector: Option<Connector>,
184) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
185where
186 R: IntoClientRequest,
187 S: Read + Write,
188{
189 let request = request.into_client_request()?;
190
191 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
192 let domain = match request.uri().host() {
193 Some(d) => Ok(d.to_string()),
194 None => Err(Error::Url(UrlError::NoHostName)),
195 }?;
196
197 let mode = uri_mode(request.uri())?;
198
199 let stream = match connector {
200 Some(conn) => match conn {
201 #[cfg(feature = "native-tls")]
202 Connector::NativeTls(conn) => {
203 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
204 }
205 #[cfg(feature = "__rustls-tls")]
206 Connector::Rustls(conn) => {
207 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
208 }
209 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
210 },
211 None => {
212 #[cfg(feature = "native-tls")]
213 {
214 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
215 }
216 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
217 {
218 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
219 }
220 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
221 {
222 self::encryption::plain::wrap_stream(stream, mode)
223 }
224 }
225 }?;
226
227 client_with_config(request, stream, config)
228}