tungstenite/handshake/mod.rs

1//! WebSocket handshake control.
2
3pub mod client;
4pub mod headers;
5pub mod machine;
6pub mod server;
7
8use std::{
9    error::Error as ErrorTrait,
10    fmt,
11    io::{Read, Write},
12};
13
14use sha1::{Digest, Sha1};
15
16use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
17use crate::error::Error;
18
19/// A WebSocket handshake.
20#[derive(Debug)]
21pub struct MidHandshake<Role: HandshakeRole> {
22    role: Role,
23    machine: HandshakeMachine<Role::InternalStream>,
24}
25
26impl<Role: HandshakeRole> MidHandshake<Role> {
27    /// Allow access to machine
28    pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
29        &self.machine
30    }
31
32    /// Allow mutable access to machine
33    pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
34        &mut self.machine
35    }
36
37    /// Restarts the handshake process.
38    pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
39        let mut mach = self.machine;
40        loop {
41            mach = match mach.single_round()? {
42                RoundResult::WouldBlock(m) => {
43                    return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
44                }
45                RoundResult::Incomplete(m) => m,
46                RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
47                    ProcessingResult::Continue(m) => m,
48                    ProcessingResult::Done(result) => return Ok(result),
49                },
50            }
51        }
52    }
53}
54
55/// A handshake result.
56pub enum HandshakeError<Role: HandshakeRole> {
57    /// Handshake was interrupted (would block).
58    Interrupted(MidHandshake<Role>),
59    /// Handshake failed.
60    Failure(Error),
61}
62
63impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        match *self {
66            HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
67            HandshakeError::Failure(ref e) => write!(f, "HandshakeError::Failure({:?})", e),
68        }
69    }
70}
71
72impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
73    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        match *self {
75            HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
76            HandshakeError::Failure(ref e) => write!(f, "{}", e),
77        }
78    }
79}
80
81impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {}
82
83impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
84    fn from(err: Error) -> Self {
85        HandshakeError::Failure(err)
86    }
87}
88
89/// Handshake role.
90pub trait HandshakeRole {
91    #[doc(hidden)]
92    type IncomingData: TryParse;
93    #[doc(hidden)]
94    type InternalStream: Read + Write;
95    #[doc(hidden)]
96    type FinalResult;
97    #[doc(hidden)]
98    fn stage_finished(
99        &mut self,
100        finish: StageResult<Self::IncomingData, Self::InternalStream>,
101    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
102}
103
104/// Stage processing result.
105#[doc(hidden)]
106#[derive(Debug)]
107pub enum ProcessingResult<Stream, FinalResult> {
108    Continue(HandshakeMachine<Stream>),
109    Done(FinalResult),
110}
111
112/// Derive the `Sec-WebSocket-Accept` response header from a `Sec-WebSocket-Key` request header.
113///
114/// This function can be used to perform a handshake before passing a raw TCP stream to
115/// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
116pub fn derive_accept_key(request_key: &[u8]) -> String {
117    // ... field is constructed by concatenating /key/ ...
118    // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
119    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
120    let mut sha1 = Sha1::default();
121    sha1.update(request_key);
122    sha1.update(WS_GUID);
123    data_encoding::BASE64.encode(&sha1.finalize())
124}
125
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// Checks the worked example given in RFC 6455, section 1.3.
    #[test]
    fn key_conversion() {
        let request_key: &[u8] = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(request_key), expected_accept);
    }
}