tungstenite/protocol/frame/
mod.rs

1//! Utilities to work with raw WebSocket frames.
2
3pub mod coding;
4
5#[allow(clippy::module_inception)]
6mod frame;
7mod mask;
8
9use crate::{
10    error::{CapacityError, Error, Result},
11    Message, ReadBuffer,
12};
13use log::*;
14use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
15
16pub use self::frame::{CloseFrame, Frame, FrameHeader};
17
18/// A reader and writer for WebSocket frames.
19#[derive(Debug)]
20pub struct FrameSocket<Stream> {
21    /// The underlying network stream.
22    stream: Stream,
23    /// Codec for reading/writing frames.
24    codec: FrameCodec,
25}
26
27impl<Stream> FrameSocket<Stream> {
28    /// Create a new frame socket.
29    pub fn new(stream: Stream) -> Self {
30        FrameSocket { stream, codec: FrameCodec::new() }
31    }
32
33    /// Create a new frame socket from partially read data.
34    pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
35        FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
36    }
37
38    /// Extract a stream from the socket.
39    pub fn into_inner(self) -> (Stream, Vec<u8>) {
40        (self.stream, self.codec.in_buffer.into_vec())
41    }
42
43    /// Returns a shared reference to the inner stream.
44    pub fn get_ref(&self) -> &Stream {
45        &self.stream
46    }
47
48    /// Returns a mutable reference to the inner stream.
49    pub fn get_mut(&mut self) -> &mut Stream {
50        &mut self.stream
51    }
52}
53
54impl<Stream> FrameSocket<Stream>
55where
56    Stream: Read,
57{
58    /// Read a frame from stream.
59    pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
60        self.codec.read_frame(&mut self.stream, max_size)
61    }
62}
63
64impl<Stream> FrameSocket<Stream>
65where
66    Stream: Write,
67{
68    /// Writes and immediately flushes a frame.
69    /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
70    pub fn send(&mut self, frame: Frame) -> Result<()> {
71        self.write(frame)?;
72        self.flush()
73    }
74
75    /// Write a frame to stream.
76    ///
77    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
78    ///
79    /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
80    /// is returned.
81    /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
82    pub fn write(&mut self, frame: Frame) -> Result<()> {
83        self.codec.buffer_frame(&mut self.stream, frame)
84    }
85
86    /// Flush writes.
87    pub fn flush(&mut self) -> Result<()> {
88        self.codec.write_out_buffer(&mut self.stream)?;
89        Ok(self.stream.flush()?)
90    }
91}
92
93/// A codec for WebSocket frames.
94#[derive(Debug)]
95pub(super) struct FrameCodec {
96    /// Buffer to read data from the stream.
97    in_buffer: ReadBuffer,
98    /// Buffer to send packets to the network.
99    out_buffer: Vec<u8>,
100    /// Capacity limit for `out_buffer`.
101    max_out_buffer_len: usize,
102    /// Buffer target length to reach before writing to the stream
103    /// on calls to `buffer_frame`.
104    ///
105    /// Setting this to non-zero will buffer small writes from hitting
106    /// the stream.
107    out_buffer_write_len: usize,
108    /// Header and remaining size of the incoming packet being processed.
109    header: Option<(FrameHeader, u64)>,
110}
111
112impl FrameCodec {
113    /// Create a new frame codec.
114    pub(super) fn new() -> Self {
115        Self {
116            in_buffer: ReadBuffer::new(),
117            out_buffer: Vec::new(),
118            max_out_buffer_len: usize::MAX,
119            out_buffer_write_len: 0,
120            header: None,
121        }
122    }
123
124    /// Create a new frame codec from partially read data.
125    pub(super) fn from_partially_read(part: Vec<u8>) -> Self {
126        Self {
127            in_buffer: ReadBuffer::from_partially_read(part),
128            out_buffer: Vec::new(),
129            max_out_buffer_len: usize::MAX,
130            out_buffer_write_len: 0,
131            header: None,
132        }
133    }
134
135    /// Sets a maximum size for the out buffer.
136    pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
137        self.max_out_buffer_len = max;
138    }
139
140    /// Sets [`Self::buffer_frame`] buffer target length to reach before
141    /// writing to the stream.
142    pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
143        self.out_buffer_write_len = len;
144    }
145
146    /// Read a frame from the provided stream.
147    pub(super) fn read_frame<Stream>(
148        &mut self,
149        stream: &mut Stream,
150        max_size: Option<usize>,
151    ) -> Result<Option<Frame>>
152    where
153        Stream: Read,
154    {
155        let max_size = max_size.unwrap_or_else(usize::max_value);
156
157        let payload = loop {
158            {
159                let cursor = self.in_buffer.as_cursor_mut();
160
161                if self.header.is_none() {
162                    self.header = FrameHeader::parse(cursor)?;
163                }
164
165                if let Some((_, ref length)) = self.header {
166                    let length = *length;
167
168                    // Enforce frame size limit early and make sure `length`
169                    // is not too big (fits into `usize`).
170                    if length > max_size as u64 {
171                        return Err(Error::Capacity(CapacityError::MessageTooLong {
172                            size: length as usize,
173                            max_size,
174                        }));
175                    }
176
177                    let input_size = cursor.get_ref().len() as u64 - cursor.position();
178                    if length <= input_size {
179                        // No truncation here since `length` is checked above
180                        let mut payload = Vec::with_capacity(length as usize);
181                        if length > 0 {
182                            cursor.take(length).read_to_end(&mut payload)?;
183                        }
184                        break payload;
185                    }
186                }
187            }
188
189            // Not enough data in buffer.
190            let size = self.in_buffer.read_from(stream)?;
191            if size == 0 {
192                trace!("no frame received");
193                return Ok(None);
194            }
195        };
196
197        let (header, length) = self.header.take().expect("Bug: no frame header");
198        debug_assert_eq!(payload.len() as u64, length);
199        let frame = Frame::from_payload(header, payload);
200        trace!("received frame {}", frame);
201        Ok(Some(frame))
202    }
203
204    /// Writes a frame into the `out_buffer`.
205    /// If the out buffer size is over the `out_buffer_write_len` will also write
206    /// the out buffer into the provided `stream`.
207    ///
208    /// To ensure buffered frames are written call [`Self::write_out_buffer`].
209    ///
210    /// May write to the stream, will **not** flush.
211    pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
212    where
213        Stream: Write,
214    {
215        if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
216            return Err(Error::WriteBufferFull(Message::Frame(frame)));
217        }
218
219        trace!("writing frame {}", frame);
220
221        self.out_buffer.reserve(frame.len());
222        frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
223
224        if self.out_buffer.len() > self.out_buffer_write_len {
225            self.write_out_buffer(stream)
226        } else {
227            Ok(())
228        }
229    }
230
231    /// Writes the out_buffer to the provided stream.
232    ///
233    /// Does **not** flush.
234    pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
235    where
236        Stream: Write,
237    {
238        while !self.out_buffer.is_empty() {
239            let len = stream.write(&self.out_buffer)?;
240            if len == 0 {
241                // This is the same as "Connection reset by peer"
242                return Err(IoError::new(
243                    IoErrorKind::ConnectionReset,
244                    "Connection reset while sending",
245                )
246                .into());
247            }
248            self.out_buffer.drain(0..len);
249        }
250
251        Ok(())
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::{Frame, FrameSocket};
    use crate::error::{CapacityError, Error};
    use std::io::Cursor;

    /// Two complete frames followed by one stray byte: both frames are
    /// returned, then `None`, and the stray byte is handed back.
    #[test]
    fn read_frames() {
        let wire = vec![
            0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
            0x99,
        ];
        let mut sock = FrameSocket::new(Cursor::new(wire));

        let first = sock.read(None).unwrap().unwrap();
        assert_eq!(first.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);

        let second = sock.read(None).unwrap().unwrap();
        assert_eq!(second.into_data(), vec![0x03, 0x02, 0x01]);

        let third = sock.read(None).unwrap();
        assert!(third.is_none());

        let (_, rest) = sock.into_inner();
        assert_eq!(rest, vec![0x99]);
    }

    /// A frame split between the pre-read bytes and the stream is
    /// reassembled into a single frame.
    #[test]
    fn from_partially_read() {
        let already_read = vec![0x82, 0x07, 0x01];
        let remaining = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut sock = FrameSocket::from_partially_read(remaining, already_read);

        let frame = sock.read(None).unwrap().unwrap();
        assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
    }

    /// Sent frames end up serialized back to back in the underlying writer.
    #[test]
    fn write_frames() {
        let mut sock = FrameSocket::new(Vec::new());

        sock.send(Frame::ping(vec![0x04, 0x05])).unwrap();
        sock.send(Frame::pong(vec![0x01])).unwrap();

        let (written, _) = sock.into_inner();
        assert_eq!(written, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
    }

    /// A header announcing the largest possible payload length must not
    /// crash the reader; the result itself is not checked.
    #[test]
    fn parse_overflow() {
        let wire = vec![
            0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
        ];
        let mut sock = FrameSocket::new(Cursor::new(wire));
        let _ = sock.read(None); // should not crash
    }

    /// A 7-byte payload read with a 5-byte limit is rejected as too long.
    #[test]
    fn size_limit_hit() {
        let wire = vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let mut sock = FrameSocket::new(Cursor::new(wire));

        let outcome = sock.read(Some(5));
        assert!(matches!(
            outcome,
            Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
        ));
    }
}