tungstenite/protocol/frame/
frame.rs

1use byteorder::{NetworkEndian, ReadBytesExt};
2use log::*;
3use std::{
4    borrow::Cow,
5    default::Default,
6    fmt,
7    io::{Cursor, ErrorKind, Read, Write},
8    result::Result as StdResult,
9    str::Utf8Error,
10    string::{FromUtf8Error, String},
11};
12
13use super::{
14    coding::{CloseCode, Control, Data, OpCode},
15    mask::{apply_mask, generate_mask},
16};
17use crate::error::{Error, ProtocolError, Result};
18
19/// A struct representing the close command.
20#[derive(Debug, Clone, Eq, PartialEq)]
21pub struct CloseFrame<'t> {
22    /// The reason as a code.
23    pub code: CloseCode,
24    /// The reason as text string.
25    pub reason: Cow<'t, str>,
26}
27
28impl<'t> CloseFrame<'t> {
29    /// Convert into a owned string.
30    pub fn into_owned(self) -> CloseFrame<'static> {
31        CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
32    }
33}
34
35impl<'t> fmt::Display for CloseFrame<'t> {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{} ({})", self.reason, self.code)
38    }
39}
40
41/// A struct representing a WebSocket frame header.
42#[allow(missing_copy_implementations)]
43#[derive(Debug, Clone, Eq, PartialEq)]
44pub struct FrameHeader {
45    /// Indicates that the frame is the last one of a possibly fragmented message.
46    pub is_final: bool,
47    /// Reserved for protocol extensions.
48    pub rsv1: bool,
49    /// Reserved for protocol extensions.
50    pub rsv2: bool,
51    /// Reserved for protocol extensions.
52    pub rsv3: bool,
53    /// WebSocket protocol opcode.
54    pub opcode: OpCode,
55    /// A frame mask, if any.
56    pub mask: Option<[u8; 4]>,
57}
58
59impl Default for FrameHeader {
60    fn default() -> Self {
61        FrameHeader {
62            is_final: true,
63            rsv1: false,
64            rsv2: false,
65            rsv3: false,
66            opcode: OpCode::Control(Control::Close),
67            mask: None,
68        }
69    }
70}
71
72impl FrameHeader {
73    /// Parse a header from an input stream.
74    /// Returns `None` if insufficient data and does not consume anything in this case.
75    /// Payload size is returned along with the header.
76    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77        let initial = cursor.position();
78        match Self::parse_internal(cursor) {
79            ret @ Ok(None) => {
80                cursor.set_position(initial);
81                ret
82            }
83            ret => ret,
84        }
85    }
86
87    /// Get the size of the header formatted with given payload length.
88    #[allow(clippy::len_without_is_empty)]
89    pub fn len(&self, length: u64) -> usize {
90        2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91    }
92
93    /// Format a header for given payload size.
94    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95        let code: u8 = self.opcode.into();
96
97        let one = {
98            code | if self.is_final { 0x80 } else { 0 }
99                | if self.rsv1 { 0x40 } else { 0 }
100                | if self.rsv2 { 0x20 } else { 0 }
101                | if self.rsv3 { 0x10 } else { 0 }
102        };
103
104        let lenfmt = LengthFormat::for_length(length);
105
106        let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107
108        output.write_all(&[one, two])?;
109        match lenfmt {
110            LengthFormat::U8(_) => (),
111            LengthFormat::U16 => {
112                output.write_all(&(length as u16).to_be_bytes())?;
113            }
114            LengthFormat::U64 => {
115                output.write_all(&length.to_be_bytes())?;
116            }
117        }
118
119        if let Some(ref mask) = self.mask {
120            output.write_all(mask)?
121        }
122
123        Ok(())
124    }
125
126    /// Generate a random frame mask and store this in the header.
127    ///
128    /// Of course this does not change frame contents. It just generates a mask.
129    pub(crate) fn set_random_mask(&mut self) {
130        self.mask = Some(generate_mask())
131    }
132}
133
134impl FrameHeader {
135    /// Internal parse engine.
136    /// Returns `None` if insufficient data.
137    /// Payload size is returned along with the header.
138    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139        let (first, second) = {
140            let mut head = [0u8; 2];
141            if cursor.read(&mut head)? != 2 {
142                return Ok(None);
143            }
144            trace!("Parsed headers {:?}", head);
145            (head[0], head[1])
146        };
147
148        trace!("First: {:b}", first);
149        trace!("Second: {:b}", second);
150
151        let is_final = first & 0x80 != 0;
152
153        let rsv1 = first & 0x40 != 0;
154        let rsv2 = first & 0x20 != 0;
155        let rsv3 = first & 0x10 != 0;
156
157        let opcode = OpCode::from(first & 0x0F);
158        trace!("Opcode: {:?}", opcode);
159
160        let masked = second & 0x80 != 0;
161        trace!("Masked: {:?}", masked);
162
163        let length = {
164            let length_byte = second & 0x7F;
165            let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
166            if length_length > 0 {
167                match cursor.read_uint::<NetworkEndian>(length_length) {
168                    Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
169                        return Ok(None);
170                    }
171                    Err(err) => {
172                        return Err(err.into());
173                    }
174                    Ok(read) => read,
175                }
176            } else {
177                u64::from(length_byte)
178            }
179        };
180
181        let mask = if masked {
182            let mut mask_bytes = [0u8; 4];
183            if cursor.read(&mut mask_bytes)? != 4 {
184                return Ok(None);
185            } else {
186                Some(mask_bytes)
187            }
188        } else {
189            None
190        };
191
192        // Disallow bad opcode
193        match opcode {
194            OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
195                return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
196            }
197            _ => (),
198        }
199
200        let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
201
202        Ok(Some((hdr, length)))
203    }
204}
205
/// A struct representing a WebSocket frame.
///
/// Pairs a [`FrameHeader`] with the payload bytes. The payload length is not stored in the
/// header; it is taken from `payload` whenever the frame's size is computed or the frame is
/// formatted.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
    /// Frame header: FIN/RSV flags, opcode and the optional masking key.
    header: FrameHeader,
    /// Payload bytes. While the header still carries a mask, these may be in masked form
    /// until `apply_mask` is called.
    payload: Vec<u8>,
}
212
213impl Frame {
214    /// Get the length of the frame.
215    /// This is the length of the header + the length of the payload.
216    #[inline]
217    pub fn len(&self) -> usize {
218        let length = self.payload.len();
219        self.header.len(length as u64) + length
220    }
221
222    /// Check if the frame is empty.
223    #[inline]
224    pub fn is_empty(&self) -> bool {
225        self.len() == 0
226    }
227
228    /// Get a reference to the frame's header.
229    #[inline]
230    pub fn header(&self) -> &FrameHeader {
231        &self.header
232    }
233
234    /// Get a mutable reference to the frame's header.
235    #[inline]
236    pub fn header_mut(&mut self) -> &mut FrameHeader {
237        &mut self.header
238    }
239
240    /// Get a reference to the frame's payload.
241    #[inline]
242    pub fn payload(&self) -> &Vec<u8> {
243        &self.payload
244    }
245
246    /// Get a mutable reference to the frame's payload.
247    #[inline]
248    pub fn payload_mut(&mut self) -> &mut Vec<u8> {
249        &mut self.payload
250    }
251
252    /// Test whether the frame is masked.
253    #[inline]
254    pub(crate) fn is_masked(&self) -> bool {
255        self.header.mask.is_some()
256    }
257
258    /// Generate a random mask for the frame.
259    ///
260    /// This just generates a mask, payload is not changed. The actual masking is performed
261    /// either on `format()` or on `apply_mask()` call.
262    #[inline]
263    pub(crate) fn set_random_mask(&mut self) {
264        self.header.set_random_mask()
265    }
266
267    /// This method unmasks the payload and should only be called on frames that are actually
268    /// masked. In other words, those frames that have just been received from a client endpoint.
269    #[inline]
270    pub(crate) fn apply_mask(&mut self) {
271        if let Some(mask) = self.header.mask.take() {
272            apply_mask(&mut self.payload, mask)
273        }
274    }
275
276    /// Consume the frame into its payload as binary.
277    #[inline]
278    pub fn into_data(self) -> Vec<u8> {
279        self.payload
280    }
281
282    /// Consume the frame into its payload as string.
283    #[inline]
284    pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
285        String::from_utf8(self.payload)
286    }
287
288    /// Get frame payload as `&str`.
289    #[inline]
290    pub fn to_text(&self) -> Result<&str, Utf8Error> {
291        std::str::from_utf8(&self.payload)
292    }
293
294    /// Consume the frame into a closing frame.
295    #[inline]
296    pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
297        match self.payload.len() {
298            0 => Ok(None),
299            1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
300            _ => {
301                let mut data = self.payload;
302                let code = u16::from_be_bytes([data[0], data[1]]).into();
303                data.drain(0..2);
304                let text = String::from_utf8(data)?;
305                Ok(Some(CloseFrame { code, reason: text.into() }))
306            }
307        }
308    }
309
310    /// Create a new data frame.
311    #[inline]
312    pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
313        debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
314
315        Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
316    }
317
318    /// Create a new Pong control frame.
319    #[inline]
320    pub fn pong(data: Vec<u8>) -> Frame {
321        Frame {
322            header: FrameHeader {
323                opcode: OpCode::Control(Control::Pong),
324                ..FrameHeader::default()
325            },
326            payload: data,
327        }
328    }
329
330    /// Create a new Ping control frame.
331    #[inline]
332    pub fn ping(data: Vec<u8>) -> Frame {
333        Frame {
334            header: FrameHeader {
335                opcode: OpCode::Control(Control::Ping),
336                ..FrameHeader::default()
337            },
338            payload: data,
339        }
340    }
341
342    /// Create a new Close control frame.
343    #[inline]
344    pub fn close(msg: Option<CloseFrame>) -> Frame {
345        let payload = if let Some(CloseFrame { code, reason }) = msg {
346            let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
347            p.extend(u16::from(code).to_be_bytes());
348            p.extend_from_slice(reason.as_bytes());
349            p
350        } else {
351            Vec::new()
352        };
353
354        Frame { header: FrameHeader::default(), payload }
355    }
356
357    /// Create a frame from given header and data.
358    pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
359        Frame { header, payload }
360    }
361
362    /// Write a frame out to a buffer
363    pub fn format(mut self, output: &mut impl Write) -> Result<()> {
364        self.header.format(self.payload.len() as u64, output)?;
365        self.apply_mask();
366        output.write_all(self.payload())?;
367        Ok(())
368    }
369}
370
371impl fmt::Display for Frame {
372    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373        use std::fmt::Write;
374
375        write!(
376            f,
377            "
378<FRAME>
379final: {}
380reserved: {} {} {}
381opcode: {}
382length: {}
383payload length: {}
384payload: 0x{}
385            ",
386            self.header.is_final,
387            self.header.rsv1,
388            self.header.rsv2,
389            self.header.rsv3,
390            self.header.opcode,
391            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
392            self.len(),
393            self.payload.len(),
394            self.payload.iter().fold(String::new(), |mut output, byte| {
395                _ = write!(output, "{byte:02x}");
396                output
397            })
398        )
399    }
400}
401
/// Handling of the length format.
///
/// The payload length lives in the low seven bits of the second header byte. It is
/// optionally followed by a 2-byte or 8-byte extended length.
enum LengthFormat {
    /// The length fits directly in the 7-bit field (0..=125) and is stored here.
    U8(u8),
    /// The 7-bit field is 126; the real length follows in 2 extra bytes.
    U16,
    /// The 7-bit field is 127; the real length follows in 8 extra bytes.
    U64,
}

impl LengthFormat {
    /// Picks the smallest encoding able to represent a payload of `length` bytes.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=0xFFFF => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the second header byte.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field of the second header byte.
    #[inline]
    fn length_byte(&self) -> u8 {
        match *self {
            Self::U8(value) => value,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Decodes the 7-bit length field. The high (mask) bit of `byte` is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        if field == 126 {
            Self::U16
        } else if field == 127 {
            Self::U64
        } else {
            Self::U8(field)
        }
    }
}
452
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    /// An unmasked binary frame with a 7-byte payload parses into the expected header and data.
    #[test]
    fn parse() {
        let bytes = vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let mut raw: Cursor<Vec<u8>> = Cursor::new(bytes);

        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);

        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload);
        assert_eq!(frame.into_data(), (1u8..=7).collect::<Vec<u8>>());
    }

    /// A small Ping frame serializes to FIN+opcode, a short length byte, then the payload.
    #[test]
    fn format() {
        let ping = Frame::ping(vec![0x01, 0x02]);
        let mut encoded = Vec::with_capacity(ping.len());
        ping.format(&mut encoded).unwrap();
        assert_eq!(encoded, [0x89, 0x02, 0x01, 0x02]);
    }

    /// The `Display` dump includes the payload line.
    #[test]
    fn display() {
        let text_frame = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
        let rendered = text_frame.to_string();
        assert!(rendered.contains("payload:"));
    }
}
484        assert!(view.contains("payload:"));
485    }
486}