1use std::{
4 io::{self, Read, Write},
5 marker::PhantomData,
6 result::Result as StdResult,
7};
8
9use http::{
10 response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11};
12use httparse::Status;
13use log::*;
14
15use super::{
16 derive_accept_key,
17 headers::{FromHttparse, MAX_HEADERS},
18 machine::{HandshakeMachine, StageResult, TryParse},
19 HandshakeRole, MidHandshake, ProcessingResult,
20};
21use crate::{
22 error::{Error, ProtocolError, Result},
23 protocol::{Role, WebSocket, WebSocketConfig},
24};
25
/// Handshake request as seen by the server: an `http` request that carries no body.
pub type Request = HttpRequest<()>;
28
/// Handshake response produced by the server: an `http` response that carries no body.
pub type Response = HttpResponse<()>;
31
/// Response used to reject a handshake; it may carry an optional textual body.
pub type ErrorResponse = HttpResponse<Option<String>>;
34
35fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
36 if request.method() != http::Method::GET {
37 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
38 }
39
40 if request.version() < http::Version::HTTP_11 {
41 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
42 }
43
44 if !request
45 .headers()
46 .get("Connection")
47 .and_then(|h| h.to_str().ok())
48 .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
49 .unwrap_or(false)
50 {
51 return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
52 }
53
54 if !request
55 .headers()
56 .get("Upgrade")
57 .and_then(|h| h.to_str().ok())
58 .map(|h| h.eq_ignore_ascii_case("websocket"))
59 .unwrap_or(false)
60 {
61 return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
62 }
63
64 if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
65 return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
66 }
67
68 let key = request
69 .headers()
70 .get("Sec-WebSocket-Key")
71 .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
72
73 let builder = Response::builder()
74 .status(StatusCode::SWITCHING_PROTOCOLS)
75 .version(request.version())
76 .header("Connection", "Upgrade")
77 .header("Upgrade", "websocket")
78 .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
79
80 Ok(builder)
81}
82
83pub fn create_response(request: &Request) -> Result<Response> {
85 Ok(create_parts(request)?.body(())?)
86}
87
88pub fn create_response_with_body<T>(
90 request: &HttpRequest<T>,
91 generate_body: impl FnOnce() -> T,
92) -> Result<HttpResponse<T>> {
93 Ok(create_parts(request)?.body(generate_body())?)
94}
95
96pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
98 writeln!(
99 w,
100 "{version:?} {status}\r",
101 version = response.version(),
102 status = response.status()
103 )?;
104
105 for (k, v) in response.headers() {
106 writeln!(w, "{}: {}\r", k, v.to_str()?)?;
107 }
108
109 writeln!(w, "\r")?;
110
111 Ok(())
112}
113
114impl TryParse for Request {
115 fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
116 let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
117 let mut req = httparse::Request::new(&mut hbuffer);
118 Ok(match req.parse(buf)? {
119 Status::Partial => None,
120 Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
121 })
122 }
123}
124
125impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
126 fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
127 if raw.method.expect("Bug: no method in header") != "GET" {
128 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
129 }
130
131 if raw.version.expect("Bug: no HTTP version") < 1 {
132 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
133 }
134
135 let headers = HeaderMap::from_httparse(raw.headers)?;
136
137 let mut request = Request::new(());
138 *request.method_mut() = http::Method::GET;
139 *request.headers_mut() = headers;
140 *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
141 *request.version_mut() = http::Version::HTTP_11;
144
145 Ok(request)
146 }
147}
148
/// Hook that lets the server inspect a handshake request and accept or reject it.
pub trait Callback: Sized {
    /// Called with the incoming `request` and the `response` the server intends to send.
    ///
    /// Return `Ok` with the (possibly modified) response to continue the handshake, or
    /// `Err` with an [`ErrorResponse`] to reject it. The callback is consumed by the call.
    fn on_request(
        self,
        request: &Request,
        response: Response,
    ) -> StdResult<Response, ErrorResponse>;
}
165
166impl<F> Callback for F
167where
168 F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
169{
170 fn on_request(
171 self,
172 request: &Request,
173 response: Response,
174 ) -> StdResult<Response, ErrorResponse> {
175 self(request, response)
176 }
177}
178
/// Callback that does nothing: it accepts every request with the response unchanged.
#[derive(Clone, Copy, Debug)]
pub struct NoCallback;
182
183impl Callback for NoCallback {
184 fn on_request(
185 self,
186 _request: &Request,
187 response: Response,
188 ) -> StdResult<Response, ErrorResponse> {
189 Ok(response)
190 }
191}
192
/// Server-side role of the WebSocket handshake state machine.
// NOTE(review): the lint is silenced here; presumably the type is deliberately not `Copy`
// even when its fields would allow it — confirm.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
    /// User callback; taken (and thus consumed) when the request has been read.
    callback: Option<C>,
    /// Configuration handed to the `WebSocket` created once the handshake succeeds.
    config: Option<WebSocketConfig>,
    /// Rejection chosen by the callback; reported as an error after it has been written.
    error_response: Option<ErrorResponse>,
    /// Ties the handshake to the stream type `S` without storing a stream.
    _marker: PhantomData<S>,
}
208
209impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
210 pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
215 trace!("Server handshake initiated.");
216 MidHandshake {
217 machine: HandshakeMachine::start_read(stream),
218 role: ServerHandshake {
219 callback: Some(callback),
220 config,
221 error_response: None,
222 _marker: PhantomData,
223 },
224 }
225 }
226}
227
228impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
229 type IncomingData = Request;
230 type InternalStream = S;
231 type FinalResult = WebSocket<S>;
232
233 fn stage_finished(
234 &mut self,
235 finish: StageResult<Self::IncomingData, Self::InternalStream>,
236 ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
237 Ok(match finish {
238 StageResult::DoneReading { stream, result, tail } => {
239 if !tail.is_empty() {
240 return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
241 }
242
243 let response = create_response(&result)?;
244 let callback_result = if let Some(callback) = self.callback.take() {
245 callback.on_request(&result, response)
246 } else {
247 Ok(response)
248 };
249
250 match callback_result {
251 Ok(response) => {
252 let mut output = vec![];
253 write_response(&mut output, &response)?;
254 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
255 }
256
257 Err(resp) => {
258 if resp.status().is_success() {
259 return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
260 }
261
262 self.error_response = Some(resp);
263 let resp = self.error_response.as_ref().unwrap();
264
265 let mut output = vec![];
266 write_response(&mut output, resp)?;
267
268 if let Some(body) = resp.body() {
269 output.extend_from_slice(body.as_bytes());
270 }
271
272 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
273 }
274 }
275 }
276
277 StageResult::DoneWriting(stream) => {
278 if let Some(err) = self.error_response.take() {
279 debug!("Server handshake failed.");
280
281 let (parts, body) = err.into_parts();
282 let body = body.map(|b| b.as_bytes().to_vec());
283 return Err(Error::Http(http::Response::from_parts(parts, body)));
284 } else {
285 debug!("Server handshake done.");
286 let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
287 ProcessingResult::Done(websocket)
288 }
289 }
290 })
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, Request};

    /// A minimal request parses and exposes its path and `Host` header.
    #[test]
    fn request_parsing() {
        let raw: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";

        let parsed = Request::try_parse(raw).unwrap();
        let (_, request) = parsed.unwrap();

        assert_eq!(request.uri().path(), "/script.ws");
        let host = request.headers().get("Host").unwrap();
        assert_eq!(host, &b"foo.com"[..]);
    }

    /// A valid upgrade request yields the expected `Sec-WebSocket-Accept` value.
    #[test]
    fn request_replying() {
        let raw: &[u8] = b"GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";

        let parsed = Request::try_parse(raw).unwrap();
        let (_, request) = parsed.unwrap();
        let reply = create_response(&request).unwrap();

        let accept = reply.headers().get("Sec-WebSocket-Accept").unwrap();
        assert_eq!(accept, &b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="[..]);
    }
}