1use std::{
4 io::{Read, Write},
5 net::{SocketAddr, TcpStream, ToSocketAddrs},
6 result::Result as StdResult,
7};
8
9use http::{request::Parts, Uri};
10use log::*;
11
12use url::Url;
13
14use crate::{
15 handshake::client::{generate_key, Request, Response},
16 protocol::WebSocketConfig,
17 stream::MaybeTlsStream,
18};
19
20use crate::{
21 error::{Error, Result, UrlError},
22 handshake::{client::ClientHandshake, HandshakeError},
23 protocol::WebSocket,
24 stream::{Mode, NoDelay},
25};
26
27pub fn connect_with_config<Req: IntoClientRequest>(
46 request: Req,
47 config: Option<WebSocketConfig>,
48 max_redirects: u8,
49) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
50 fn try_client_handshake(
51 request: Request,
52 config: Option<WebSocketConfig>,
53 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
54 let uri = request.uri();
55 let mode = uri_mode(uri)?;
56 let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
57 let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58 let port = uri.port_u16().unwrap_or(match mode {
59 Mode::Plain => 80,
60 Mode::Tls => 443,
61 });
62 let addrs = (host, port).to_socket_addrs()?;
63 let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
64 NoDelay::set_nodelay(&mut stream, true)?;
65
66 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
67 let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
68 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
69 let client = crate::tls::client_tls_with_config(request, stream, config, None);
70
71 client.map_err(|e| match e {
72 HandshakeError::Failure(f) => f,
73 HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
74 })
75 }
76
77 fn create_request(parts: &Parts, uri: &Uri) -> Request {
78 let mut builder =
79 Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
80 *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
81 builder.body(()).expect("Failed to create `Request`")
82 }
83
84 let (parts, _) = request.into_client_request()?.into_parts();
85 let mut uri = parts.uri.clone();
86
87 for attempt in 0..(max_redirects + 1) {
88 let request = create_request(&parts, &uri);
89
90 match try_client_handshake(request, config) {
91 Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
92 if let Some(location) = res.headers().get("Location") {
93 uri = location.to_str()?.parse::<Uri>()?;
94 debug!("Redirecting to {:?}", uri);
95 continue;
96 } else {
97 warn!("No `Location` found in redirect");
98 return Err(Error::Http(res));
99 }
100 }
101 other => return other,
102 }
103 }
104
105 unreachable!("Bug in a redirect handling logic")
106}
107
108pub fn connect<Req: IntoClientRequest>(
121 request: Req,
122) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
123 connect_with_config(request, None, 3)
124}
125
126fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
127 for addr in addrs {
128 debug!("Trying to contact {} at {}...", uri, addr);
129 if let Ok(stream) = TcpStream::connect(addr) {
130 return Ok(stream);
131 }
132 }
133 Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
134}
135
136pub fn uri_mode(uri: &Uri) -> Result<Mode> {
141 match uri.scheme_str() {
142 Some("ws") => Ok(Mode::Plain),
143 Some("wss") => Ok(Mode::Tls),
144 _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
145 }
146}
147
148pub fn client_with_config<Stream, Req>(
155 request: Req,
156 stream: Stream,
157 config: Option<WebSocketConfig>,
158) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
159where
160 Stream: Read + Write,
161 Req: IntoClientRequest,
162{
163 ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
164}
165
166pub fn client<Stream, Req>(
172 request: Req,
173 stream: Stream,
174) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
175where
176 Stream: Read + Write,
177 Req: IntoClientRequest,
178{
179 client_with_config(request, stream, None)
180}
181
/// Conversion of a value into a client handshake [`Request`].
///
/// The connect/client functions in this module accept any implementor. Implementations exist
/// in this module for `&str`, `&String`, `String`, `&Uri`, `Uri`, `&Url`, `Url`,
/// [`Request`] itself, and `httparse::Request`.
pub trait IntoClientRequest {
    /// Converts `self` into a [`Request`] used to start a client handshake.
    ///
    /// # Errors
    ///
    /// Returns an error when no valid request can be built, e.g. an unparsable URI or a URI
    /// without a host.
    fn into_client_request(self) -> Result<Request>;
}
194
195impl<'a> IntoClientRequest for &'a str {
196 fn into_client_request(self) -> Result<Request> {
197 self.parse::<Uri>()?.into_client_request()
198 }
199}
200
201impl<'a> IntoClientRequest for &'a String {
202 fn into_client_request(self) -> Result<Request> {
203 <&str as IntoClientRequest>::into_client_request(self)
204 }
205}
206
207impl IntoClientRequest for String {
208 fn into_client_request(self) -> Result<Request> {
209 <&str as IntoClientRequest>::into_client_request(&self)
210 }
211}
212
213impl<'a> IntoClientRequest for &'a Uri {
214 fn into_client_request(self) -> Result<Request> {
215 self.clone().into_client_request()
216 }
217}
218
219impl IntoClientRequest for Uri {
220 fn into_client_request(self) -> Result<Request> {
221 let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
222 let host = authority
223 .find('@')
224 .map(|idx| authority.split_at(idx + 1).1)
225 .unwrap_or_else(|| authority);
226
227 if host.is_empty() {
228 return Err(Error::Url(UrlError::EmptyHostName));
229 }
230
231 let req = Request::builder()
232 .method("GET")
233 .header("Host", host)
234 .header("Connection", "Upgrade")
235 .header("Upgrade", "websocket")
236 .header("Sec-WebSocket-Version", "13")
237 .header("Sec-WebSocket-Key", generate_key())
238 .uri(self)
239 .body(())?;
240 Ok(req)
241 }
242}
243
244impl<'a> IntoClientRequest for &'a Url {
245 fn into_client_request(self) -> Result<Request> {
246 self.as_str().into_client_request()
247 }
248}
249
250impl IntoClientRequest for Url {
251 fn into_client_request(self) -> Result<Request> {
252 self.as_str().into_client_request()
253 }
254}
255
/// A ready-made [`Request`] is used as-is. Its headers are not validated or filled in here.
impl IntoClientRequest for Request {
    fn into_client_request(self) -> Result<Request> {
        Ok(self)
    }
}
261
262impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
263 fn into_client_request(self) -> Result<Request> {
264 use crate::handshake::headers::FromHttparse;
265 Request::from_httparse(self)
266 }
267}