tungstenite/protocol/
message.rs

1use std::{
2    convert::{AsRef, From, Into, TryFrom},
3    fmt,
4    result::Result as StdResult,
5    str,
6};
7
8use super::frame::{CloseFrame, Frame};
9use crate::error::{CapacityError, Error, Result};
10
11mod string_collect {
12    use utf8::DecodeError;
13
14    use crate::error::{Error, Result};
15
16    #[derive(Debug)]
17    pub struct StringCollector {
18        data: String,
19        incomplete: Option<utf8::Incomplete>,
20    }
21
22    impl StringCollector {
23        pub fn new() -> Self {
24            StringCollector { data: String::new(), incomplete: None }
25        }
26
27        pub fn len(&self) -> usize {
28            self.data
29                .len()
30                .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
31        }
32
33        pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
34            let mut input: &[u8] = tail.as_ref();
35
36            if let Some(mut incomplete) = self.incomplete.take() {
37                if let Some((result, rest)) = incomplete.try_complete(input) {
38                    input = rest;
39                    if let Ok(text) = result {
40                        self.data.push_str(text);
41                    } else {
42                        return Err(Error::Utf8);
43                    }
44                } else {
45                    input = &[];
46                    self.incomplete = Some(incomplete);
47                }
48            }
49
50            if !input.is_empty() {
51                match utf8::decode(input) {
52                    Ok(text) => {
53                        self.data.push_str(text);
54                        Ok(())
55                    }
56                    Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
57                        self.data.push_str(valid_prefix);
58                        self.incomplete = Some(incomplete_suffix);
59                        Ok(())
60                    }
61                    Err(DecodeError::Invalid { valid_prefix, .. }) => {
62                        self.data.push_str(valid_prefix);
63                        Err(Error::Utf8)
64                    }
65                }
66            } else {
67                Ok(())
68            }
69        }
70
71        pub fn into_string(self) -> Result<String> {
72            if self.incomplete.is_some() {
73                Err(Error::Utf8)
74            } else {
75                Ok(self.data)
76            }
77        }
78    }
79}
80
81use self::string_collect::StringCollector;
82
83/// A struct representing the incomplete message.
84#[derive(Debug)]
85pub struct IncompleteMessage {
86    collector: IncompleteMessageCollector,
87}
88
89#[derive(Debug)]
90enum IncompleteMessageCollector {
91    Text(StringCollector),
92    Binary(Vec<u8>),
93}
94
95impl IncompleteMessage {
96    /// Create new.
97    pub fn new(message_type: IncompleteMessageType) -> Self {
98        IncompleteMessage {
99            collector: match message_type {
100                IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
101                IncompleteMessageType::Text => {
102                    IncompleteMessageCollector::Text(StringCollector::new())
103                }
104            },
105        }
106    }
107
108    /// Get the current filled size of the buffer.
109    pub fn len(&self) -> usize {
110        match self.collector {
111            IncompleteMessageCollector::Text(ref t) => t.len(),
112            IncompleteMessageCollector::Binary(ref b) => b.len(),
113        }
114    }
115
116    /// Add more data to an existing message.
117    pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
118        // Always have a max size. This ensures an error in case of concatenating two buffers
119        // of more than `usize::max_value()` bytes in total.
120        let max_size = size_limit.unwrap_or_else(usize::max_value);
121        let my_size = self.len();
122        let portion_size = tail.as_ref().len();
123        // Be careful about integer overflows here.
124        if my_size > max_size || portion_size > max_size - my_size {
125            return Err(Error::Capacity(CapacityError::MessageTooLong {
126                size: my_size + portion_size,
127                max_size,
128            }));
129        }
130
131        match self.collector {
132            IncompleteMessageCollector::Binary(ref mut v) => {
133                v.extend(tail.as_ref());
134                Ok(())
135            }
136            IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
137        }
138    }
139
140    /// Convert an incomplete message into a complete one.
141    pub fn complete(self) -> Result<Message> {
142        match self.collector {
143            IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
144            IncompleteMessageCollector::Text(t) => {
145                let text = t.into_string()?;
146                Ok(Message::Text(text))
147            }
148        }
149    }
150}
151
/// The type of incomplete message.
///
/// Selects which kind of buffer `IncompleteMessage::new` creates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncompleteMessageType {
    /// The message being collected is a text message.
    Text,
    /// The message being collected is a binary message.
    Binary,
}
157
158/// An enum representing the various forms of a WebSocket message.
159#[derive(Debug, Eq, PartialEq, Clone)]
160pub enum Message {
161    /// A text WebSocket message
162    Text(String),
163    /// A binary WebSocket message
164    Binary(Vec<u8>),
165    /// A ping message with the specified payload
166    ///
167    /// The payload here must have a length less than 125 bytes
168    Ping(Vec<u8>),
169    /// A pong message with the specified payload
170    ///
171    /// The payload here must have a length less than 125 bytes
172    Pong(Vec<u8>),
173    /// A close message with the optional close frame.
174    Close(Option<CloseFrame<'static>>),
175    /// Raw frame. Note, that you're not going to get this value while reading the message.
176    Frame(Frame),
177}
178
179impl Message {
180    /// Create a new text WebSocket message from a stringable.
181    pub fn text<S>(string: S) -> Message
182    where
183        S: Into<String>,
184    {
185        Message::Text(string.into())
186    }
187
188    /// Create a new binary WebSocket message by converting to `Vec<u8>`.
189    pub fn binary<B>(bin: B) -> Message
190    where
191        B: Into<Vec<u8>>,
192    {
193        Message::Binary(bin.into())
194    }
195
196    /// Indicates whether a message is a text message.
197    pub fn is_text(&self) -> bool {
198        matches!(*self, Message::Text(_))
199    }
200
201    /// Indicates whether a message is a binary message.
202    pub fn is_binary(&self) -> bool {
203        matches!(*self, Message::Binary(_))
204    }
205
206    /// Indicates whether a message is a ping message.
207    pub fn is_ping(&self) -> bool {
208        matches!(*self, Message::Ping(_))
209    }
210
211    /// Indicates whether a message is a pong message.
212    pub fn is_pong(&self) -> bool {
213        matches!(*self, Message::Pong(_))
214    }
215
216    /// Indicates whether a message is a close message.
217    pub fn is_close(&self) -> bool {
218        matches!(*self, Message::Close(_))
219    }
220
221    /// Get the length of the WebSocket message.
222    pub fn len(&self) -> usize {
223        match *self {
224            Message::Text(ref string) => string.len(),
225            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
226                data.len()
227            }
228            Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
229            Message::Frame(ref frame) => frame.len(),
230        }
231    }
232
233    /// Returns true if the WebSocket message has no content.
234    /// For example, if the other side of the connection sent an empty string.
235    pub fn is_empty(&self) -> bool {
236        self.len() == 0
237    }
238
239    /// Consume the WebSocket and return it as binary data.
240    pub fn into_data(self) -> Vec<u8> {
241        match self {
242            Message::Text(string) => string.into_bytes(),
243            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
244            Message::Close(None) => Vec::new(),
245            Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
246            Message::Frame(frame) => frame.into_data(),
247        }
248    }
249
250    /// Attempt to consume the WebSocket message and convert it to a String.
251    pub fn into_text(self) -> Result<String> {
252        match self {
253            Message::Text(string) => Ok(string),
254            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
255                Ok(String::from_utf8(data)?)
256            }
257            Message::Close(None) => Ok(String::new()),
258            Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
259            Message::Frame(frame) => Ok(frame.into_string()?),
260        }
261    }
262
263    /// Attempt to get a &str from the WebSocket message,
264    /// this will try to convert binary data to utf8.
265    pub fn to_text(&self) -> Result<&str> {
266        match *self {
267            Message::Text(ref string) => Ok(string),
268            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
269                Ok(str::from_utf8(data)?)
270            }
271            Message::Close(None) => Ok(""),
272            Message::Close(Some(ref frame)) => Ok(&frame.reason),
273            Message::Frame(ref frame) => Ok(frame.to_text()?),
274        }
275    }
276}
277
278impl From<String> for Message {
279    fn from(string: String) -> Self {
280        Message::text(string)
281    }
282}
283
284impl<'s> From<&'s str> for Message {
285    fn from(string: &'s str) -> Self {
286        Message::text(string)
287    }
288}
289
290impl<'b> From<&'b [u8]> for Message {
291    fn from(data: &'b [u8]) -> Self {
292        Message::binary(data)
293    }
294}
295
296impl From<Vec<u8>> for Message {
297    fn from(data: Vec<u8>) -> Self {
298        Message::binary(data)
299    }
300}
301
302impl From<Message> for Vec<u8> {
303    fn from(message: Message) -> Self {
304        message.into_data()
305    }
306}
307
308impl TryFrom<Message> for String {
309    type Error = Error;
310
311    fn try_from(value: Message) -> StdResult<Self, Self::Error> {
312        value.into_text()
313    }
314}
315
316impl fmt::Display for Message {
317    fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
318        if let Ok(string) = self.to_text() {
319            write!(f, "{}", string)
320        } else {
321            write!(f, "Binary Data<length={}>", self.len())
322        }
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Text is displayed verbatim, non-UTF-8 binary as a length placeholder.
    #[test]
    fn display() {
        let text_msg = Message::text("test".to_owned());
        assert_eq!(text_msg.to_string(), "test");

        let binary_msg = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(binary_msg.to_string(), "Binary Data<length=5>");
    }

    /// A byte slice converts into a binary message that is not valid text.
    #[test]
    fn binary_convert() {
        let bytes = [6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(&bytes[..]);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    /// An owned vector converts into a binary message that is not valid text.
    #[test]
    fn binary_convert_vec() {
        let msg = Message::from(vec![6u8, 7, 8, 9, 10, 241]);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    /// Converting a binary message back into a vector returns the same bytes.
    #[test]
    fn binary_convert_into_vec() {
        let original = vec![6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(original.clone());
        let round_tripped: Vec<u8> = msg.into();
        assert_eq!(original, round_tripped);
    }

    /// A string slice converts into a text message.
    #[test]
    fn text_convert() {
        let msg = Message::from("kiwotsukete");
        assert!(msg.is_text());
    }
}