tungstenite/handshake/server.rs

1//! Server handshake machine.
2
3use std::{
4    io::{self, Read, Write},
5    marker::PhantomData,
6    result::Result as StdResult,
7};
8
9use http::{
10    response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11};
12use httparse::Status;
13use log::*;
14
15use super::{
16    derive_accept_key,
17    headers::{FromHttparse, MAX_HEADERS},
18    machine::{HandshakeMachine, StageResult, TryParse},
19    HandshakeRole, MidHandshake, ProcessingResult,
20};
21use crate::{
22    error::{Error, ProtocolError, Result},
23    protocol::{Role, WebSocket, WebSocketConfig},
24};
25
26/// Server request type.
27pub type Request = HttpRequest<()>;
28
29/// Server response type.
30pub type Response = HttpResponse<()>;
31
32/// Server error response type.
33pub type ErrorResponse = HttpResponse<Option<String>>;
34
35fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
36    if request.method() != http::Method::GET {
37        return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
38    }
39
40    if request.version() < http::Version::HTTP_11 {
41        return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
42    }
43
44    if !request
45        .headers()
46        .get("Connection")
47        .and_then(|h| h.to_str().ok())
48        .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
49        .unwrap_or(false)
50    {
51        return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
52    }
53
54    if !request
55        .headers()
56        .get("Upgrade")
57        .and_then(|h| h.to_str().ok())
58        .map(|h| h.eq_ignore_ascii_case("websocket"))
59        .unwrap_or(false)
60    {
61        return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
62    }
63
64    if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
65        return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
66    }
67
68    let key = request
69        .headers()
70        .get("Sec-WebSocket-Key")
71        .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
72
73    let builder = Response::builder()
74        .status(StatusCode::SWITCHING_PROTOCOLS)
75        .version(request.version())
76        .header("Connection", "Upgrade")
77        .header("Upgrade", "websocket")
78        .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
79
80    Ok(builder)
81}
82
83/// Create a response for the request.
84pub fn create_response(request: &Request) -> Result<Response> {
85    Ok(create_parts(request)?.body(())?)
86}
87
88/// Create a response for the request with a custom body.
89pub fn create_response_with_body<T>(
90    request: &HttpRequest<T>,
91    generate_body: impl FnOnce() -> T,
92) -> Result<HttpResponse<T>> {
93    Ok(create_parts(request)?.body(generate_body())?)
94}
95
96/// Write `response` to the stream `w`.
97pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
98    writeln!(
99        w,
100        "{version:?} {status}\r",
101        version = response.version(),
102        status = response.status()
103    )?;
104
105    for (k, v) in response.headers() {
106        writeln!(w, "{}: {}\r", k, v.to_str()?)?;
107    }
108
109    writeln!(w, "\r")?;
110
111    Ok(())
112}
113
114impl TryParse for Request {
115    fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
116        let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
117        let mut req = httparse::Request::new(&mut hbuffer);
118        Ok(match req.parse(buf)? {
119            Status::Partial => None,
120            Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
121        })
122    }
123}
124
125impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
126    fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
127        if raw.method.expect("Bug: no method in header") != "GET" {
128            return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
129        }
130
131        if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
132            return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
133        }
134
135        let headers = HeaderMap::from_httparse(raw.headers)?;
136
137        let mut request = Request::new(());
138        *request.method_mut() = http::Method::GET;
139        *request.headers_mut() = headers;
140        *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
141        // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
142        // so the only valid value we could get in the response would be 1.1.
143        *request.version_mut() = http::Version::HTTP_11;
144
145        Ok(request)
146    }
147}
148
149/// The callback trait.
150///
151/// The callback is called when the server receives an incoming WebSocket
152/// handshake request from the client. Specifying a callback allows you to analyze incoming headers
153/// and add additional headers to the response that server sends to the client and/or reject the
154/// connection based on the incoming headers.
155pub trait Callback: Sized {
156    /// Called whenever the server read the request from the client and is ready to reply to it.
157    /// May return additional reply headers.
158    /// Returning an error resulting in rejecting the incoming connection.
159    fn on_request(
160        self,
161        request: &Request,
162        response: Response,
163    ) -> StdResult<Response, ErrorResponse>;
164}
165
166impl<F> Callback for F
167where
168    F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
169{
170    fn on_request(
171        self,
172        request: &Request,
173        response: Response,
174    ) -> StdResult<Response, ErrorResponse> {
175        self(request, response)
176    }
177}
178
179/// Stub for callback that does nothing.
180#[derive(Clone, Copy, Debug)]
181pub struct NoCallback;
182
183impl Callback for NoCallback {
184    fn on_request(
185        self,
186        _request: &Request,
187        response: Response,
188    ) -> StdResult<Response, ErrorResponse> {
189        Ok(response)
190    }
191}
192
193/// Server handshake role.
194#[allow(missing_copy_implementations)]
195#[derive(Debug)]
196pub struct ServerHandshake<S, C> {
197    /// Callback which is called whenever the server read the request from the client and is ready
198    /// to reply to it. The callback returns an optional headers which will be added to the reply
199    /// which the server sends to the user.
200    callback: Option<C>,
201    /// WebSocket configuration.
202    config: Option<WebSocketConfig>,
203    /// Error code/flag. If set, an error will be returned after sending response to the client.
204    error_response: Option<ErrorResponse>,
205    /// Internal stream type.
206    _marker: PhantomData<S>,
207}
208
209impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
210    /// Start server handshake. `callback` specifies a custom callback which the user can pass to
211    /// the handshake, this callback will be called when the a websocket client connects to the
212    /// server, you can specify the callback if you want to add additional header to the client
213    /// upon join based on the incoming headers.
214    pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
215        trace!("Server handshake initiated.");
216        MidHandshake {
217            machine: HandshakeMachine::start_read(stream),
218            role: ServerHandshake {
219                callback: Some(callback),
220                config,
221                error_response: None,
222                _marker: PhantomData,
223            },
224        }
225    }
226}
227
228impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
229    type IncomingData = Request;
230    type InternalStream = S;
231    type FinalResult = WebSocket<S>;
232
233    fn stage_finished(
234        &mut self,
235        finish: StageResult<Self::IncomingData, Self::InternalStream>,
236    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
237        Ok(match finish {
238            StageResult::DoneReading { stream, result, tail } => {
239                if !tail.is_empty() {
240                    return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
241                }
242
243                let response = create_response(&result)?;
244                let callback_result = if let Some(callback) = self.callback.take() {
245                    callback.on_request(&result, response)
246                } else {
247                    Ok(response)
248                };
249
250                match callback_result {
251                    Ok(response) => {
252                        let mut output = vec![];
253                        write_response(&mut output, &response)?;
254                        ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
255                    }
256
257                    Err(resp) => {
258                        if resp.status().is_success() {
259                            return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
260                        }
261
262                        self.error_response = Some(resp);
263                        let resp = self.error_response.as_ref().unwrap();
264
265                        let mut output = vec![];
266                        write_response(&mut output, resp)?;
267
268                        if let Some(body) = resp.body() {
269                            output.extend_from_slice(body.as_bytes());
270                        }
271
272                        ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
273                    }
274                }
275            }
276
277            StageResult::DoneWriting(stream) => {
278                if let Some(err) = self.error_response.take() {
279                    debug!("Server handshake failed.");
280
281                    let (parts, body) = err.into_parts();
282                    let body = body.map(|b| b.as_bytes().to_vec());
283                    return Err(Error::Http(http::Response::from_parts(parts, body)));
284                } else {
285                    debug!("Server handshake done.");
286                    let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
287                    ProcessingResult::Done(websocket)
288                }
289            }
290        })
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, Request};

    /// A plain GET head parses into its path and headers.
    #[test]
    fn request_parsing() {
        const REQUEST: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let (_, parsed) = Request::try_parse(REQUEST).unwrap().unwrap();
        assert_eq!(parsed.uri().path(), "/script.ws");
        let host = parsed.headers().get("Host").unwrap();
        assert_eq!(host, &b"foo.com"[..]);
    }

    /// The accept key matches the worked example from RFC 6455.
    #[test]
    fn request_replying() {
        const REQUEST: &[u8] = b"\
            GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let (_, parsed) = Request::try_parse(REQUEST).unwrap().unwrap();
        let reply = create_response(&parsed).unwrap();
        let accept = reply.headers().get("Sec-WebSocket-Accept").unwrap();
        assert_eq!(accept, b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref());
    }
}