tungstenite/
client.rs

1//! Methods to connect to a WebSocket as a client.
2
3use std::{
4    io::{Read, Write},
5    net::{SocketAddr, TcpStream, ToSocketAddrs},
6    result::Result as StdResult,
7};
8
9use http::{request::Parts, Uri};
10use log::*;
11
12use url::Url;
13
14use crate::{
15    handshake::client::{generate_key, Request, Response},
16    protocol::WebSocketConfig,
17    stream::MaybeTlsStream,
18};
19
20use crate::{
21    error::{Error, Result, UrlError},
22    handshake::{client::ClientHandshake, HandshakeError},
23    protocol::WebSocket,
24    stream::{Mode, NoDelay},
25};
26
27/// Connect to the given WebSocket in blocking mode.
28///
29/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
30/// equal to calling `connect()` function.
31///
32/// The URL may be either ws:// or wss://.
33/// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
34/// project's [README][readme] for more information on available features.
35///
36/// This function "just works" for those who wants a simple blocking solution
37/// similar to `std::net::TcpStream`. If you want a non-blocking or other
38/// custom stream, call `client` instead.
39///
40/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
41/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
42/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
43///
44/// [readme]: https://github.com/snapview/tungstenite-rs/#features
45pub fn connect_with_config<Req: IntoClientRequest>(
46    request: Req,
47    config: Option<WebSocketConfig>,
48    max_redirects: u8,
49) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
50    fn try_client_handshake(
51        request: Request,
52        config: Option<WebSocketConfig>,
53    ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
54        let uri = request.uri();
55        let mode = uri_mode(uri)?;
56        let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
57        let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58        let port = uri.port_u16().unwrap_or(match mode {
59            Mode::Plain => 80,
60            Mode::Tls => 443,
61        });
62        let addrs = (host, port).to_socket_addrs()?;
63        let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
64        NoDelay::set_nodelay(&mut stream, true)?;
65
66        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
67        let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
68        #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
69        let client = crate::tls::client_tls_with_config(request, stream, config, None);
70
71        client.map_err(|e| match e {
72            HandshakeError::Failure(f) => f,
73            HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
74        })
75    }
76
77    fn create_request(parts: &Parts, uri: &Uri) -> Request {
78        let mut builder =
79            Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
80        *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
81        builder.body(()).expect("Failed to create `Request`")
82    }
83
84    let (parts, _) = request.into_client_request()?.into_parts();
85    let mut uri = parts.uri.clone();
86
87    for attempt in 0..(max_redirects + 1) {
88        let request = create_request(&parts, &uri);
89
90        match try_client_handshake(request, config) {
91            Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
92                if let Some(location) = res.headers().get("Location") {
93                    uri = location.to_str()?.parse::<Uri>()?;
94                    debug!("Redirecting to {:?}", uri);
95                    continue;
96                } else {
97                    warn!("No `Location` found in redirect");
98                    return Err(Error::Http(res));
99                }
100            }
101            other => return other,
102        }
103    }
104
105    unreachable!("Bug in a redirect handling logic")
106}
107
108/// Connect to the given WebSocket in blocking mode.
109///
110/// The URL may be either ws:// or wss://.
111/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
112///
113/// This function "just works" for those who wants a simple blocking solution
114/// similar to `std::net::TcpStream`. If you want a non-blocking or other
115/// custom stream, call `client` instead.
116///
117/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
118/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
119/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
120pub fn connect<Req: IntoClientRequest>(
121    request: Req,
122) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
123    connect_with_config(request, None, 3)
124}
125
126fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
127    for addr in addrs {
128        debug!("Trying to contact {} at {}...", uri, addr);
129        if let Ok(stream) = TcpStream::connect(addr) {
130            return Ok(stream);
131        }
132    }
133    Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
134}
135
136/// Get the mode of the given URL.
137///
138/// This function may be used to ease the creation of custom TLS streams
139/// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
140pub fn uri_mode(uri: &Uri) -> Result<Mode> {
141    match uri.scheme_str() {
142        Some("ws") => Ok(Mode::Plain),
143        Some("wss") => Ok(Mode::Tls),
144        _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
145    }
146}
147
148/// Do the client handshake over the given stream given a web socket configuration. Passing `None`
149/// as configuration is equal to calling `client()` function.
150///
151/// Use this function if you need a nonblocking handshake support or if you
152/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
153/// Any stream supporting `Read + Write` will do.
154pub fn client_with_config<Stream, Req>(
155    request: Req,
156    stream: Stream,
157    config: Option<WebSocketConfig>,
158) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
159where
160    Stream: Read + Write,
161    Req: IntoClientRequest,
162{
163    ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
164}
165
166/// Do the client handshake over the given stream.
167///
168/// Use this function if you need a nonblocking handshake support or if you
169/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
170/// Any stream supporting `Read + Write` will do.
171pub fn client<Stream, Req>(
172    request: Req,
173    stream: Stream,
174) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
175where
176    Stream: Read + Write,
177    Req: IntoClientRequest,
178{
179    client_with_config(request, stream, None)
180}
181
182/// Trait for converting various types into HTTP requests used for a client connection.
183///
184/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
185/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
186/// simply take your request and pass it as is further without altering any headers or URLs, so
187/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
188/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
189/// the proper `http::Request<()>` for you.
190pub trait IntoClientRequest {
191    /// Convert into a `Request` that can be used for a client connection.
192    fn into_client_request(self) -> Result<Request>;
193}
194
195impl<'a> IntoClientRequest for &'a str {
196    fn into_client_request(self) -> Result<Request> {
197        self.parse::<Uri>()?.into_client_request()
198    }
199}
200
201impl<'a> IntoClientRequest for &'a String {
202    fn into_client_request(self) -> Result<Request> {
203        <&str as IntoClientRequest>::into_client_request(self)
204    }
205}
206
207impl IntoClientRequest for String {
208    fn into_client_request(self) -> Result<Request> {
209        <&str as IntoClientRequest>::into_client_request(&self)
210    }
211}
212
213impl<'a> IntoClientRequest for &'a Uri {
214    fn into_client_request(self) -> Result<Request> {
215        self.clone().into_client_request()
216    }
217}
218
219impl IntoClientRequest for Uri {
220    fn into_client_request(self) -> Result<Request> {
221        let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
222        let host = authority
223            .find('@')
224            .map(|idx| authority.split_at(idx + 1).1)
225            .unwrap_or_else(|| authority);
226
227        if host.is_empty() {
228            return Err(Error::Url(UrlError::EmptyHostName));
229        }
230
231        let req = Request::builder()
232            .method("GET")
233            .header("Host", host)
234            .header("Connection", "Upgrade")
235            .header("Upgrade", "websocket")
236            .header("Sec-WebSocket-Version", "13")
237            .header("Sec-WebSocket-Key", generate_key())
238            .uri(self)
239            .body(())?;
240        Ok(req)
241    }
242}
243
244impl<'a> IntoClientRequest for &'a Url {
245    fn into_client_request(self) -> Result<Request> {
246        self.as_str().into_client_request()
247    }
248}
249
250impl IntoClientRequest for Url {
251    fn into_client_request(self) -> Result<Request> {
252        self.as_str().into_client_request()
253    }
254}
255
256impl IntoClientRequest for Request {
257    fn into_client_request(self) -> Result<Request> {
258        Ok(self)
259    }
260}
261
262impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
263    fn into_client_request(self) -> Result<Request> {
264        use crate::handshake::headers::FromHttparse;
265        Request::from_httparse(self)
266    }
267}