tungstenite/protocol/frame/
frame.rs1use byteorder::{NetworkEndian, ReadBytesExt};
2use log::*;
3use std::{
4 borrow::Cow,
5 default::Default,
6 fmt,
7 io::{Cursor, ErrorKind, Read, Write},
8 result::Result as StdResult,
9 str::Utf8Error,
10 string::{FromUtf8Error, String},
11};
12
13use super::{
14 coding::{CloseCode, Control, Data, OpCode},
15 mask::{apply_mask, generate_mask},
16};
17use crate::error::{Error, ProtocolError, Result};
18
19#[derive(Debug, Clone, Eq, PartialEq)]
21pub struct CloseFrame<'t> {
22 pub code: CloseCode,
24 pub reason: Cow<'t, str>,
26}
27
28impl<'t> CloseFrame<'t> {
29 pub fn into_owned(self) -> CloseFrame<'static> {
31 CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
32 }
33}
34
35impl<'t> fmt::Display for CloseFrame<'t> {
36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37 write!(f, "{} ({})", self.reason, self.code)
38 }
39}
40
41#[allow(missing_copy_implementations)]
43#[derive(Debug, Clone, Eq, PartialEq)]
44pub struct FrameHeader {
45 pub is_final: bool,
47 pub rsv1: bool,
49 pub rsv2: bool,
51 pub rsv3: bool,
53 pub opcode: OpCode,
55 pub mask: Option<[u8; 4]>,
57}
58
59impl Default for FrameHeader {
60 fn default() -> Self {
61 FrameHeader {
62 is_final: true,
63 rsv1: false,
64 rsv2: false,
65 rsv3: false,
66 opcode: OpCode::Control(Control::Close),
67 mask: None,
68 }
69 }
70}
71
72impl FrameHeader {
73 pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77 let initial = cursor.position();
78 match Self::parse_internal(cursor) {
79 ret @ Ok(None) => {
80 cursor.set_position(initial);
81 ret
82 }
83 ret => ret,
84 }
85 }
86
87 #[allow(clippy::len_without_is_empty)]
89 pub fn len(&self, length: u64) -> usize {
90 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91 }
92
93 pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95 let code: u8 = self.opcode.into();
96
97 let one = {
98 code | if self.is_final { 0x80 } else { 0 }
99 | if self.rsv1 { 0x40 } else { 0 }
100 | if self.rsv2 { 0x20 } else { 0 }
101 | if self.rsv3 { 0x10 } else { 0 }
102 };
103
104 let lenfmt = LengthFormat::for_length(length);
105
106 let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107
108 output.write_all(&[one, two])?;
109 match lenfmt {
110 LengthFormat::U8(_) => (),
111 LengthFormat::U16 => {
112 output.write_all(&(length as u16).to_be_bytes())?;
113 }
114 LengthFormat::U64 => {
115 output.write_all(&length.to_be_bytes())?;
116 }
117 }
118
119 if let Some(ref mask) = self.mask {
120 output.write_all(mask)?
121 }
122
123 Ok(())
124 }
125
126 pub(crate) fn set_random_mask(&mut self) {
130 self.mask = Some(generate_mask())
131 }
132}
133
134impl FrameHeader {
135 fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139 let (first, second) = {
140 let mut head = [0u8; 2];
141 if cursor.read(&mut head)? != 2 {
142 return Ok(None);
143 }
144 trace!("Parsed headers {:?}", head);
145 (head[0], head[1])
146 };
147
148 trace!("First: {:b}", first);
149 trace!("Second: {:b}", second);
150
151 let is_final = first & 0x80 != 0;
152
153 let rsv1 = first & 0x40 != 0;
154 let rsv2 = first & 0x20 != 0;
155 let rsv3 = first & 0x10 != 0;
156
157 let opcode = OpCode::from(first & 0x0F);
158 trace!("Opcode: {:?}", opcode);
159
160 let masked = second & 0x80 != 0;
161 trace!("Masked: {:?}", masked);
162
163 let length = {
164 let length_byte = second & 0x7F;
165 let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
166 if length_length > 0 {
167 match cursor.read_uint::<NetworkEndian>(length_length) {
168 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
169 return Ok(None);
170 }
171 Err(err) => {
172 return Err(err.into());
173 }
174 Ok(read) => read,
175 }
176 } else {
177 u64::from(length_byte)
178 }
179 };
180
181 let mask = if masked {
182 let mut mask_bytes = [0u8; 4];
183 if cursor.read(&mut mask_bytes)? != 4 {
184 return Ok(None);
185 } else {
186 Some(mask_bytes)
187 }
188 } else {
189 None
190 };
191
192 match opcode {
194 OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
195 return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
196 }
197 _ => (),
198 }
199
200 let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
201
202 Ok(Some((hdr, length)))
203 }
204}
205
206#[derive(Debug, Clone, Eq, PartialEq)]
208pub struct Frame {
209 header: FrameHeader,
210 payload: Vec<u8>,
211}
212
213impl Frame {
214 #[inline]
217 pub fn len(&self) -> usize {
218 let length = self.payload.len();
219 self.header.len(length as u64) + length
220 }
221
222 #[inline]
224 pub fn is_empty(&self) -> bool {
225 self.len() == 0
226 }
227
228 #[inline]
230 pub fn header(&self) -> &FrameHeader {
231 &self.header
232 }
233
234 #[inline]
236 pub fn header_mut(&mut self) -> &mut FrameHeader {
237 &mut self.header
238 }
239
240 #[inline]
242 pub fn payload(&self) -> &Vec<u8> {
243 &self.payload
244 }
245
246 #[inline]
248 pub fn payload_mut(&mut self) -> &mut Vec<u8> {
249 &mut self.payload
250 }
251
252 #[inline]
254 pub(crate) fn is_masked(&self) -> bool {
255 self.header.mask.is_some()
256 }
257
258 #[inline]
263 pub(crate) fn set_random_mask(&mut self) {
264 self.header.set_random_mask()
265 }
266
267 #[inline]
270 pub(crate) fn apply_mask(&mut self) {
271 if let Some(mask) = self.header.mask.take() {
272 apply_mask(&mut self.payload, mask)
273 }
274 }
275
276 #[inline]
278 pub fn into_data(self) -> Vec<u8> {
279 self.payload
280 }
281
282 #[inline]
284 pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
285 String::from_utf8(self.payload)
286 }
287
288 #[inline]
290 pub fn to_text(&self) -> Result<&str, Utf8Error> {
291 std::str::from_utf8(&self.payload)
292 }
293
294 #[inline]
296 pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
297 match self.payload.len() {
298 0 => Ok(None),
299 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
300 _ => {
301 let mut data = self.payload;
302 let code = u16::from_be_bytes([data[0], data[1]]).into();
303 data.drain(0..2);
304 let text = String::from_utf8(data)?;
305 Ok(Some(CloseFrame { code, reason: text.into() }))
306 }
307 }
308 }
309
310 #[inline]
312 pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
313 debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
314
315 Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
316 }
317
318 #[inline]
320 pub fn pong(data: Vec<u8>) -> Frame {
321 Frame {
322 header: FrameHeader {
323 opcode: OpCode::Control(Control::Pong),
324 ..FrameHeader::default()
325 },
326 payload: data,
327 }
328 }
329
330 #[inline]
332 pub fn ping(data: Vec<u8>) -> Frame {
333 Frame {
334 header: FrameHeader {
335 opcode: OpCode::Control(Control::Ping),
336 ..FrameHeader::default()
337 },
338 payload: data,
339 }
340 }
341
342 #[inline]
344 pub fn close(msg: Option<CloseFrame>) -> Frame {
345 let payload = if let Some(CloseFrame { code, reason }) = msg {
346 let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
347 p.extend(u16::from(code).to_be_bytes());
348 p.extend_from_slice(reason.as_bytes());
349 p
350 } else {
351 Vec::new()
352 };
353
354 Frame { header: FrameHeader::default(), payload }
355 }
356
357 pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
359 Frame { header, payload }
360 }
361
362 pub fn format(mut self, output: &mut impl Write) -> Result<()> {
364 self.header.format(self.payload.len() as u64, output)?;
365 self.apply_mask();
366 output.write_all(self.payload())?;
367 Ok(())
368 }
369}
370
371impl fmt::Display for Frame {
372 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373 use std::fmt::Write;
374
375 write!(
376 f,
377 "
378<FRAME>
379final: {}
380reserved: {} {} {}
381opcode: {}
382length: {}
383payload length: {}
384payload: 0x{}
385 ",
386 self.header.is_final,
387 self.header.rsv1,
388 self.header.rsv2,
389 self.header.rsv3,
390 self.header.opcode,
391 self.len(),
393 self.payload.len(),
394 self.payload.iter().fold(String::new(), |mut output, byte| {
395 _ = write!(output, "{byte:02x}");
396 output
397 })
398 )
399 }
400}
401
/// How a frame's payload length is encoded in its header.
enum LengthFormat {
    /// The length (0..=125) is stored directly in the 7-bit length field.
    U8(u8),
    /// The 7-bit field holds 126 and a 16-bit big-endian length follows.
    U16,
    /// The 7-bit field holds 127 and a 64-bit big-endian length follows.
    U64,
}

impl LengthFormat {
    /// Chooses the shortest encoding able to represent `length`.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=0xFFFF => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the second header byte.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value written into the 7-bit length field of the second header byte.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            Self::U8(small) => *small,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Decodes the 7-bit length field of `byte`; the mask bit (0x80) is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        if field == 126 {
            Self::U16
        } else if field == 127 {
            Self::U64
        } else {
            Self::U8(field)
        }
    }
}
452
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    #[test]
    fn parse() {
        let mut raw: Cursor<Vec<u8>> =
            Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload);
        assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
    }

    /// A header cut off inside its 16-bit extended length is reported as incomplete,
    /// and the cursor is rewound so parsing can be retried.
    #[test]
    fn parse_incomplete_header_rewinds() {
        let mut raw = Cursor::new(vec![0x82u8, 0x7E, 0x01]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);
    }

    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    /// Payloads of 126..=65535 bytes use the 126 marker plus a big-endian u16 length.
    #[test]
    fn format_extended_length() {
        let frame = Frame::message(vec![0xABu8; 200], OpCode::Data(Data::Binary), true);
        assert_eq!(frame.len(), 204);
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();
        assert_eq!(buf.len(), 204);
        assert_eq!(buf[..4], [0x82u8, 0x7E, 0x00, 0xC8]);
    }

    /// Formatting a masked frame and parsing it back recovers the original payload
    /// once the mask is re-applied.
    #[test]
    fn masked_round_trip() {
        let key = [0x01u8, 0x02, 0x03, 0x04];
        let mut frame = Frame::message(b"hello".to_vec(), OpCode::Data(Data::Text), true);
        frame.header_mut().mask = Some(key);
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();

        let mut raw = Cursor::new(buf);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 5);
        assert_eq!(header.mask, Some(key));
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let mut parsed = Frame::from_payload(header, payload);
        parsed.apply_mask();
        assert_eq!(parsed.into_data(), b"hello".to_vec());
    }

    /// Close payloads encode the code big-endian followed by the reason, and decode back.
    #[test]
    fn close_round_trip() {
        let frame =
            Frame::close(Some(CloseFrame { code: CloseCode::Normal, reason: "bye".into() }));
        assert_eq!(frame.payload(), &vec![0x03u8, 0xE8, b'b', b'y', b'e']);
        let close = frame.into_close().unwrap().unwrap();
        assert_eq!(close.code, CloseCode::Normal);
        assert_eq!(close.reason, "bye");

        assert!(Frame::close(None).into_close().unwrap().is_none());
        assert!(Frame::from_payload(FrameHeader::default(), vec![0x03]).into_close().is_err());
    }

    #[test]
    fn display() {
        let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
        let view = format!("{}", f);
        assert!(view.contains("payload:"));
    }
}