1#![deny(missing_docs, unused_must_use, unused_mut, unused_imports, unused_import_braces)]
12
13pub use tungstenite;
14
15mod compat;
16#[cfg(feature = "connect")]
17mod connect;
18mod handshake;
19#[cfg(feature = "stream")]
20mod stream;
21#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
22mod tls;
23
24use std::io::{Read, Write};
25
26use compat::{cvt, AllowStd, ContextWaker};
27use futures_util::{
28 sink::{Sink, SinkExt},
29 stream::{FusedStream, Stream},
30};
31use log::*;
32use std::{
33 pin::Pin,
34 task::{Context, Poll},
35};
36use tokio::io::{AsyncRead, AsyncWrite};
37
38#[cfg(feature = "handshake")]
39use tungstenite::{
40 client::IntoClientRequest,
41 handshake::{
42 client::{ClientHandshake, Response},
43 server::{Callback, NoCallback},
44 HandshakeError,
45 },
46};
47use tungstenite::{
48 error::Error as WsError,
49 protocol::{Message, Role, WebSocket, WebSocketConfig},
50};
51
52#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
53pub use tls::Connector;
54#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
55pub use tls::{client_async_tls, client_async_tls_with_config};
56
57#[cfg(feature = "connect")]
58pub use connect::{connect_async, connect_async_with_config};
59
60#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "connect"))]
61pub use connect::connect_async_tls_with_config;
62
63#[cfg(feature = "stream")]
64pub use stream::MaybeTlsStream;
65
66use tungstenite::protocol::CloseFrame;
67
#[cfg(feature = "handshake")]
/// Performs a client WebSocket handshake over an already-established `stream`.
///
/// `request` may be anything implementing [`IntoClientRequest`] (tungstenite
/// provides implementations for URL strings as well as a full `Request`, the
/// latter allowing custom headers such as a sub-protocol).
///
/// This is typically used when the caller already holds a connection (for
/// example a TCP stream) to the remote server. It uses no custom
/// [`WebSocketConfig`]; see [`client_async_with_config`] to supply one.
///
/// # Errors
/// Returns the same errors as [`client_async_with_config`].
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
91
#[cfg(feature = "handshake")]
/// Like [`client_async`], but lets the caller pass an explicit [`WebSocketConfig`].
///
/// The `request` is converted with [`IntoClientRequest`], and a tungstenite
/// [`ClientHandshake`] is driven to completion over `stream`.
///
/// # Errors
/// - A [`HandshakeError::Failure`] from tungstenite is returned as its inner [`WsError`].
/// - Any other handshake outcome is reported as [`WsError::Io`] with
///   `ErrorKind::Other` and the handshake error's text as the message.
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = handshake::client_handshake(stream, move |allow_std| {
        let req = request.into_client_request()?;
        ClientHandshake::start(allow_std, req, config)?.handshake()
    });

    match pending.await {
        Ok(established) => Ok(established),
        // A real failure reported by tungstenite is passed through unchanged.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Every non-`Failure` outcome is flattened into an I/O error with its description.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
114
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection over the provided `stream`.
///
/// Performs the server side of the handshake with no header callback and no
/// custom [`WebSocketConfig`], resolving to the ready [`WebSocketStream`].
///
/// # Errors
/// Returns the same errors as [`accept_hdr_async_with_config`].
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Equivalent to `accept_hdr_async(stream, NoCallback)`, which forwards here with `None`.
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
133
#[cfg(feature = "handshake")]
/// Like [`accept_async`], but lets the caller pass an explicit [`WebSocketConfig`].
///
/// # Errors
/// Returns the same errors as [`accept_hdr_async_with_config`].
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No header inspection is wanted here, so use tungstenite's no-op callback.
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
146
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection, invoking `callback` on the client's
/// handshake request.
///
/// The callback (a tungstenite [`Callback`]) can inspect the request and shape
/// the handshake response. No custom [`WebSocketConfig`] is used.
///
/// # Errors
/// Returns the same errors as [`accept_hdr_async_with_config`].
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, config).await
}
160
#[cfg(feature = "handshake")]
/// Like [`accept_hdr_async`], but lets the caller pass an explicit [`WebSocketConfig`].
///
/// The server handshake is performed by `tungstenite::accept_hdr_with_config`
/// over `stream`, with `callback` and `config` forwarded unchanged.
///
/// # Errors
/// - A [`HandshakeError::Failure`] from tungstenite is returned as its inner [`WsError`].
/// - Any other handshake outcome is reported as [`WsError::Io`] with
///   `ErrorKind::Other` and the handshake error's text as the message.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let pending = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match pending.await {
        Ok(ws_stream) => Ok(ws_stream),
        // A real failure reported by tungstenite is passed through unchanged.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Every non-`Failure` outcome is flattened into an I/O error with its description.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
181
182#[derive(Debug)]
192pub struct WebSocketStream<S> {
193 inner: WebSocket<AllowStd<S>>,
194 closing: bool,
195 ended: bool,
196 ready: bool,
201}
202
203impl<S> WebSocketStream<S> {
204 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
207 where
208 S: AsyncRead + AsyncWrite + Unpin,
209 {
210 handshake::without_handshake(stream, move |allow_std| {
211 WebSocket::from_raw_socket(allow_std, role, config)
212 })
213 .await
214 }
215
216 pub async fn from_partially_read(
219 stream: S,
220 part: Vec<u8>,
221 role: Role,
222 config: Option<WebSocketConfig>,
223 ) -> Self
224 where
225 S: AsyncRead + AsyncWrite + Unpin,
226 {
227 handshake::without_handshake(stream, move |allow_std| {
228 WebSocket::from_partially_read(allow_std, part, role, config)
229 })
230 .await
231 }
232
233 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
234 Self { inner: ws, closing: false, ended: false, ready: true }
235 }
236
237 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
238 where
239 S: Unpin,
240 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
241 AllowStd<S>: Read + Write,
242 {
243 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
244 if let Some((kind, ctx)) = ctx {
245 self.inner.get_mut().set_waker(kind, ctx.waker());
246 }
247 f(&mut self.inner)
248 }
249
250 pub fn get_ref(&self) -> &S
252 where
253 S: AsyncRead + AsyncWrite + Unpin,
254 {
255 self.inner.get_ref().get_ref()
256 }
257
258 pub fn get_mut(&mut self) -> &mut S
260 where
261 S: AsyncRead + AsyncWrite + Unpin,
262 {
263 self.inner.get_mut().get_mut()
264 }
265
266 pub fn get_config(&self) -> &WebSocketConfig {
268 self.inner.get_config()
269 }
270
271 pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
273 where
274 S: AsyncRead + AsyncWrite + Unpin,
275 {
276 let msg = msg.map(|msg| msg.into_owned());
277 self.send(Message::Close(msg)).await
278 }
279}
280
281impl<T> Stream for WebSocketStream<T>
282where
283 T: AsyncRead + AsyncWrite + Unpin,
284{
285 type Item = Result<Message, WsError>;
286
287 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
288 trace!("{}:{} Stream.poll_next", file!(), line!());
289
290 if self.ended {
294 return Poll::Ready(None);
295 }
296
297 match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
298 trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
299 cvt(s.read())
300 })) {
301 Ok(v) => Poll::Ready(Some(Ok(v))),
302 Err(e) => {
303 self.ended = true;
304 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
305 Poll::Ready(None)
306 } else {
307 Poll::Ready(Some(Err(e)))
308 }
309 }
310 }
311 }
312}
313
314impl<T> FusedStream for WebSocketStream<T>
315where
316 T: AsyncRead + AsyncWrite + Unpin,
317{
318 fn is_terminated(&self) -> bool {
319 self.ended
320 }
321}
322
323impl<T> Sink<Message> for WebSocketStream<T>
324where
325 T: AsyncRead + AsyncWrite + Unpin,
326{
327 type Error = WsError;
328
329 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330 if self.ready {
331 Poll::Ready(Ok(()))
332 } else {
333 (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
335 self.ready = true;
336 r
337 })
338 }
339 }
340
341 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
342 match (*self).with_context(None, |s| s.write(item)) {
343 Ok(()) => {
344 self.ready = true;
345 Ok(())
346 }
347 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
348 self.ready = false;
351 Ok(())
352 }
353 Err(e) => {
354 self.ready = true;
355 debug!("websocket start_send error: {}", e);
356 Err(e)
357 }
358 }
359 }
360
361 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362 (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
363 self.ready = true;
364 match r {
365 Err(WsError::ConnectionClosed) => Ok(()),
367 other => other,
368 }
369 })
370 }
371
372 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373 self.ready = true;
374 let res = if self.closing {
375 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
377 } else {
378 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
379 };
380
381 match res {
382 Ok(()) => Poll::Ready(Ok(())),
383 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
384 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
385 trace!("WouldBlock");
386 self.closing = true;
387 Poll::Pending
388 }
389 Err(err) => {
390 debug!("websocket close error: {}", err);
391 Poll::Ready(Err(err))
392 }
393 }
394 }
395}
396
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
/// Returns the host component of `request`'s URI as an owned `String`.
///
/// With the `__rustls-tls` feature, a bracketed host such as `[::1]` is
/// returned without its surrounding brackets.
///
/// # Errors
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    let host = request
        .uri()
        .host()
        .ok_or(WsError::Url(tungstenite::error::UrlError::NoHostName))?;

    #[cfg(feature = "__rustls-tls")]
    {
        // rustls expects IPv6 addresses without the surrounding `[]`.
        if let Some(unbracketed) = host.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
            return Ok(unbracketed.to_string());
        }
    }

    Ok(host.to_string())
}
409
#[cfg(test)]
mod tests {
    use crate::{compat::AllowStd, WebSocketStream};
    use std::io::{Read, Write};
    use tokio::net::TcpStream;

    // Compile-time witnesses: each call only type-checks if `T` has the trait.
    fn implements_read<T: Read>() {}
    fn implements_write<T: Write>() {}
    fn implements_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        implements_read::<AllowStd<TcpStream>>();
        implements_write::<AllowStd<TcpStream>>();
        implements_unpin::<WebSocketStream<TcpStream>>();

        #[cfg(feature = "connect")]
        {
            use crate::stream::MaybeTlsStream;
            use tokio::io::{AsyncReadExt, AsyncWriteExt};

            fn implements_async_read<T: AsyncReadExt>() {}
            fn implements_async_write<T: AsyncWriteExt>() {}

            implements_async_read::<MaybeTlsStream<TcpStream>>();
            implements_async_write::<MaybeTlsStream<TcpStream>>();
            implements_unpin::<WebSocketStream<MaybeTlsStream<TcpStream>>>();
        }
    }
}