rustls/msgs/
deframer.rs

1use alloc::vec::Vec;
2use core::ops::Range;
3use core::slice::SliceIndex;
4use std::io;
5
6use super::base::Payload;
7use super::codec::Codec;
8use super::message::PlainMessage;
9use crate::enums::{ContentType, ProtocolVersion};
10use crate::error::{Error, InvalidMessage, PeerMisbehaved};
11use crate::msgs::codec;
12use crate::msgs::message::{MessageError, OpaqueMessage};
13use crate::record_layer::{Decrypted, RecordLayer};
14
15/// This deframer works to reconstruct TLS messages from a stream of arbitrary-sized reads.
16///
17/// It buffers incoming data into a `Vec` through `read()`, and returns messages through `pop()`.
18/// QUIC connections will call `push()` to append handshake payload data directly.
19#[derive(Default)]
20pub struct MessageDeframer {
21    /// Set if the peer is not talking TLS, but some other
22    /// protocol.  The caller should abort the connection, because
23    /// the deframer cannot recover.
24    last_error: Option<Error>,
25
26    /// If we're in the middle of joining a handshake payload, this is the metadata.
27    joining_hs: Option<HandshakePayloadMeta>,
28}
29
30impl MessageDeframer {
31    /// Return any decrypted messages that the deframer has been able to parse.
32    ///
33    /// Returns an `Error` if the deframer failed to parse some message contents or if decryption
34    /// failed, `Ok(None)` if no full message is buffered or if trial decryption failed, and
35    /// `Ok(Some(_))` if a valid message was found and decrypted successfully.
36    pub fn pop(
37        &mut self,
38        record_layer: &mut RecordLayer,
39        negotiated_version: Option<ProtocolVersion>,
40        buffer: &mut DeframerSliceBuffer,
41    ) -> Result<Option<Deframed>, Error> {
42        if let Some(last_err) = self.last_error.clone() {
43            return Err(last_err);
44        } else if buffer.is_empty() {
45            return Ok(None);
46        }
47
48        // We loop over records we've received but not processed yet.
49        // For records that decrypt as `Handshake`, we keep the current state of the joined
50        // handshake message payload in `self.joining_hs`, appending to it as we see records.
51        let expected_len = loop {
52            let start = match &self.joining_hs {
53                Some(meta) => {
54                    match meta.expected_len {
55                        // We're joining a handshake payload, and we've seen the full payload.
56                        Some(len) if len <= meta.payload.len() => break len,
57                        // Not enough data, and we can't parse any more out of the buffer (QUIC).
58                        _ if meta.quic => return Ok(None),
59                        // Try parsing some more of the encrypted buffered data.
60                        _ => meta.message.end,
61                    }
62                }
63                None => 0,
64            };
65
66            // Does our `buf` contain a full message?  It does if it is big enough to
67            // contain a header, and that header has a length which falls within `buf`.
68            // If so, deframe it and place the message onto the frames output queue.
69            let mut rd = codec::Reader::init(buffer.filled_get(start..));
70            let m = match OpaqueMessage::read(&mut rd) {
71                Ok(m) => m,
72                Err(msg_err) => {
73                    let err_kind = match msg_err {
74                        MessageError::TooShortForHeader | MessageError::TooShortForLength => {
75                            return Ok(None)
76                        }
77                        MessageError::InvalidEmptyPayload => InvalidMessage::InvalidEmptyPayload,
78                        MessageError::MessageTooLarge => InvalidMessage::MessageTooLarge,
79                        MessageError::InvalidContentType => InvalidMessage::InvalidContentType,
80                        MessageError::UnknownProtocolVersion => {
81                            InvalidMessage::UnknownProtocolVersion
82                        }
83                    };
84
85                    return Err(self.set_err(err_kind));
86                }
87            };
88
89            // Return CCS messages and early plaintext alerts immediately without decrypting.
90            let end = start + rd.used();
91            let version_is_tls13 = matches!(negotiated_version, Some(ProtocolVersion::TLSv1_3));
92            let allowed_plaintext = match m.typ {
93                // CCS messages are always plaintext.
94                ContentType::ChangeCipherSpec => true,
95                // Alerts are allowed to be plaintext if-and-only-if:
96                // * The negotiated protocol version is TLS 1.3. - In TLS 1.2 it is unambiguous when
97                //   keying changes based on the CCS message. Only TLS 1.3 requires these heuristics.
98                // * We have not yet decrypted any messages from the peer - if we have we don't
99                //   expect any plaintext.
100                // * The payload size is indicative of a plaintext alert message.
101                ContentType::Alert
102                    if version_is_tls13
103                        && !record_layer.has_decrypted()
104                        && m.payload().len() <= 2 =>
105                {
106                    true
107                }
108                // In other circumstances, we expect all messages to be encrypted.
109                _ => false,
110            };
111            if self.joining_hs.is_none() && allowed_plaintext {
112                // This is unencrypted. We check the contents later.
113                buffer.queue_discard(end);
114                return Ok(Some(Deframed {
115                    want_close_before_decrypt: false,
116                    aligned: true,
117                    trial_decryption_finished: false,
118                    message: m.into_plain_message(),
119                }));
120            }
121
122            // Decrypt the encrypted message (if necessary).
123            let msg = match record_layer.decrypt_incoming(m) {
124                Ok(Some(decrypted)) => {
125                    let Decrypted {
126                        want_close_before_decrypt,
127                        plaintext,
128                    } = decrypted;
129                    debug_assert!(!want_close_before_decrypt);
130                    plaintext
131                }
132                // This was rejected early data, discard it. If we currently have a handshake
133                // payload in progress, this counts as interleaved, so we error out.
134                Ok(None) if self.joining_hs.is_some() => {
135                    return Err(self.set_err(
136                        PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage,
137                    ));
138                }
139                Ok(None) => {
140                    buffer.queue_discard(end);
141                    continue;
142                }
143                Err(e) => return Err(e),
144            };
145
146            if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
147                // "Handshake messages MUST NOT be interleaved with other record
148                // types.  That is, if a handshake message is split over two or more
149                // records, there MUST NOT be any other records between them."
150                // https://www.rfc-editor.org/rfc/rfc8446#section-5.1
151                return Err(self.set_err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage));
152            }
153
154            // If it's not a handshake message, just return it -- no joining necessary.
155            if msg.typ != ContentType::Handshake {
156                let end = start + rd.used();
157                buffer.queue_discard(end);
158                return Ok(Some(Deframed {
159                    want_close_before_decrypt: false,
160                    aligned: true,
161                    trial_decryption_finished: false,
162                    message: msg,
163                }));
164            }
165
166            // If we don't know the payload size yet or if the payload size is larger
167            // than the currently buffered payload, we need to wait for more data.
168            match self.append_hs::<_, false>(msg.version, &msg.payload.0, end, buffer)? {
169                HandshakePayloadState::Blocked => return Ok(None),
170                HandshakePayloadState::Complete(len) => break len,
171                HandshakePayloadState::Continue => continue,
172            }
173        };
174
175        let meta = self.joining_hs.as_mut().unwrap(); // safe after calling `append_hs()`
176
177        // We can now wrap the complete handshake payload in a `PlainMessage`, to be returned.
178        let message = PlainMessage {
179            typ: ContentType::Handshake,
180            version: meta.version,
181            payload: Payload::new(
182                buffer.filled_get(meta.payload.start..meta.payload.start + expected_len),
183            ),
184        };
185
186        // But before we return, update the `joining_hs` state to skip past this payload.
187        if meta.payload.len() > expected_len {
188            // If we have another (beginning of) a handshake payload left in the buffer, update
189            // the payload start to point past the payload we're about to yield, and update the
190            // `expected_len` to match the state of that remaining payload.
191            meta.payload.start += expected_len;
192            meta.expected_len =
193                payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
194        } else {
195            // Otherwise, we've yielded the last handshake payload in the buffer, so we can
196            // discard all of the bytes that we're previously buffered as handshake data.
197            let end = meta.message.end;
198            self.joining_hs = None;
199            buffer.queue_discard(end);
200        }
201
202        Ok(Some(Deframed {
203            want_close_before_decrypt: false,
204            aligned: self.joining_hs.is_none(),
205            trial_decryption_finished: true,
206            message,
207        }))
208    }
209
210    /// Fuses this deframer's error and returns the set value.
211    ///
212    /// Any future calls to `pop` will return `err` again.
213    fn set_err(&mut self, err: impl Into<Error>) -> Error {
214        let err = err.into();
215        self.last_error = Some(err.clone());
216        err
217    }
218
219    /// Allow pushing handshake messages directly into the buffer.
220    pub(crate) fn push(
221        &mut self,
222        version: ProtocolVersion,
223        payload: &[u8],
224        buffer: &mut DeframerVecBuffer,
225    ) -> Result<(), Error> {
226        if !buffer.is_empty() && self.joining_hs.is_none() {
227            return Err(Error::General(
228                "cannot push QUIC messages into unrelated connection".into(),
229            ));
230        } else if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
231            return Err(Error::General(err.into()));
232        }
233
234        let end = buffer.len() + payload.len();
235        self.append_hs::<_, true>(version, payload, end, buffer)?;
236        Ok(())
237    }
238
239    /// Write the handshake message contents into the buffer and update the metadata.
240    ///
241    /// Returns true if a complete message is found.
242    fn append_hs<T: DeframerBuffer<QUIC>, const QUIC: bool>(
243        &mut self,
244        version: ProtocolVersion,
245        payload: &[u8],
246        end: usize,
247        buffer: &mut T,
248    ) -> Result<HandshakePayloadState, Error> {
249        let meta = match &mut self.joining_hs {
250            Some(meta) => {
251                debug_assert_eq!(meta.quic, QUIC);
252
253                // We're joining a handshake message to the previous one here.
254                // Write it into the buffer and update the metadata.
255
256                DeframerBuffer::<QUIC>::copy(buffer, payload, meta.payload.end);
257                meta.message.end = end;
258                meta.payload.end += payload.len();
259
260                // If we haven't parsed the payload size yet, try to do so now.
261                if meta.expected_len.is_none() {
262                    meta.expected_len =
263                        payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
264                }
265
266                meta
267            }
268            None => {
269                // We've found a new handshake message here.
270                // Write it into the buffer and create the metadata.
271
272                let expected_len = payload_size(payload)?;
273                DeframerBuffer::<QUIC>::copy(buffer, payload, 0);
274                self.joining_hs
275                    .insert(HandshakePayloadMeta {
276                        message: Range { start: 0, end },
277                        payload: Range {
278                            start: 0,
279                            end: payload.len(),
280                        },
281                        version,
282                        expected_len,
283                        quic: QUIC,
284                    })
285            }
286        };
287
288        Ok(match meta.expected_len {
289            Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
290            _ => match buffer.len() > meta.message.end {
291                true => HandshakePayloadState::Continue,
292                false => HandshakePayloadState::Blocked,
293            },
294        })
295    }
296
297    /// Read some bytes from `rd`, and add them to our internal buffer.
298    #[allow(clippy::comparison_chain)]
299    pub fn read(
300        &mut self,
301        rd: &mut dyn io::Read,
302        buffer: &mut DeframerVecBuffer,
303    ) -> io::Result<usize> {
304        if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
305            return Err(io::Error::new(io::ErrorKind::InvalidData, err));
306        }
307
308        // Try to do the largest reads possible. Note that if
309        // we get a message with a length field out of range here,
310        // we do a zero length read.  That looks like an EOF to
311        // the next layer up, which is fine.
312        let new_bytes = rd.read(buffer.unfilled())?;
313        buffer.advance(new_bytes);
314        Ok(new_bytes)
315    }
316}
317
/// Owned storage for bytes received from the peer that have not been fully deframed yet.
#[derive(Default, Debug)]
pub struct DeframerVecBuffer {
    /// Backing storage; only the first `used` bytes hold received data.
    ///
    /// Its size is managed by [`DeframerVecBuffer::prepare_read()`].
    buf: Vec<u8>,

    /// Number of bytes at the front of `buf` that are filled with received data.
    used: usize,
}
328
329impl DeframerVecBuffer {
330    /// Borrows the initialized contents of this buffer and tracks pending discard operations via
331    /// the `discard` reference
332    pub fn borrow(&mut self) -> DeframerSliceBuffer {
333        DeframerSliceBuffer::new(&mut self.buf[..self.used])
334    }
335
336    /// Returns true if there are messages for the caller to process
337    pub fn has_pending(&self) -> bool {
338        !self.is_empty()
339    }
340
341    /// Resize the internal `buf` if necessary for reading more bytes.
342    fn prepare_read(&mut self, is_joining_hs: bool) -> Result<(), &'static str> {
343        // We allow a maximum of 64k of buffered data for handshake messages only. Enforce this
344        // by varying the maximum allowed buffer size here based on whether a prefix of a
345        // handshake payload is currently being buffered. Given that the first read of such a
346        // payload will only ever be 4k bytes, the next time we come around here we allow a
347        // larger buffer size. Once the large message and any following handshake messages in
348        // the same flight have been consumed, `pop()` will call `discard()` to reset `used`.
349        // At this point, the buffer resizing logic below should reduce the buffer size.
350        let allow_max = match is_joining_hs {
351            true => MAX_HANDSHAKE_SIZE as usize,
352            false => OpaqueMessage::MAX_WIRE_SIZE,
353        };
354
355        if self.used >= allow_max {
356            return Err("message buffer full");
357        }
358
359        // If we can and need to increase the buffer size to allow a 4k read, do so. After
360        // dealing with a large handshake message (exceeding `OpaqueMessage::MAX_WIRE_SIZE`),
361        // make sure to reduce the buffer size again (large messages should be rare).
362        // Also, reduce the buffer size if there are neither full nor partial messages in it,
363        // which usually means that the other side suspended sending data.
364        let need_capacity = Ord::min(allow_max, self.used + READ_SIZE);
365        if need_capacity > self.buf.len() {
366            self.buf.resize(need_capacity, 0);
367        } else if self.used == 0 || self.buf.len() > allow_max {
368            self.buf.resize(need_capacity, 0);
369            self.buf.shrink_to(need_capacity);
370        }
371
372        Ok(())
373    }
374
375    /// Discard `taken` bytes from the start of our buffer.
376    pub fn discard(&mut self, taken: usize) {
377        #[allow(clippy::comparison_chain)]
378        if taken < self.used {
379            /* Before:
380             * +----------+----------+----------+
381             * | taken    | pending  |xxxxxxxxxx|
382             * +----------+----------+----------+
383             * 0          ^ taken    ^ self.used
384             *
385             * After:
386             * +----------+----------+----------+
387             * | pending  |xxxxxxxxxxxxxxxxxxxxx|
388             * +----------+----------+----------+
389             * 0          ^ self.used
390             */
391
392            self.buf
393                .copy_within(taken..self.used, 0);
394            self.used -= taken;
395        } else if taken == self.used {
396            self.used = 0;
397        }
398    }
399
400    fn is_empty(&self) -> bool {
401        self.len() == 0
402    }
403
404    fn advance(&mut self, num_bytes: usize) {
405        self.used += num_bytes;
406    }
407
408    fn unfilled(&mut self) -> &mut [u8] {
409        &mut self.buf[self.used..]
410    }
411}
412
413impl FilledDeframerBuffer for DeframerVecBuffer {
414    fn filled_mut(&mut self) -> &mut [u8] {
415        &mut self.buf[..self.used]
416    }
417
418    fn filled(&self) -> &[u8] {
419        &self.buf[..self.used]
420    }
421}
422
423impl DeframerBuffer<true> for DeframerVecBuffer {
424    fn copy(&mut self, src: &[u8], at: usize) {
425        copy_into_buffer(self.unfilled(), src, at);
426        self.advance(src.len());
427    }
428}
429
430impl DeframerBuffer<false> for DeframerVecBuffer {
431    fn copy(&mut self, src: &[u8], at: usize) {
432        self.borrow().copy(src, at)
433    }
434}
435
/// A mutable view over the filled part of a [`DeframerVecBuffer`].
///
/// It remembers how many bytes at its front have been consumed, without moving any data.
pub struct DeframerSliceBuffer<'a> {
    // Fully initialized bytes that are being deframed.
    buf: &'a mut [u8],
    // Count of consumed bytes at the front of `buf`; the owner of the underlying storage
    // is expected to discard them later.
    discard: usize,
}
443
444impl<'a> DeframerSliceBuffer<'a> {
445    pub fn new(buf: &'a mut [u8]) -> Self {
446        Self { buf, discard: 0 }
447    }
448
449    /// Tracks a pending discard operation of `num_bytes`
450    pub fn queue_discard(&mut self, num_bytes: usize) {
451        self.discard += num_bytes;
452    }
453
454    /// Returns the number of bytes that need to be discarded
455    pub fn pending_discard(&self) -> usize {
456        self.discard
457    }
458
459    pub fn is_empty(&self) -> bool {
460        self.len() == 0
461    }
462}
463
464impl FilledDeframerBuffer for DeframerSliceBuffer<'_> {
465    fn filled_mut(&mut self) -> &mut [u8] {
466        &mut self.buf[self.discard..]
467    }
468
469    fn filled(&self) -> &[u8] {
470        &self.buf[self.discard..]
471    }
472}
473
474impl DeframerBuffer<false> for DeframerSliceBuffer<'_> {
475    fn copy(&mut self, src: &[u8], at: usize) {
476        copy_into_buffer(self.filled_mut(), src, at)
477    }
478}
479
480trait DeframerBuffer<const QUIC: bool>: FilledDeframerBuffer {
481    /// Copies from the `src` buffer into this buffer at the requested index
482    ///
483    /// If `QUIC` is true the data will be copied into the *un*filled section of the buffer
484    ///
485    /// If `QUIC` is false the data will be copied into the filled section of the buffer
486    fn copy(&mut self, src: &[u8], at: usize);
487}
488
/// Overwrites `buf[at..at + src.len()]` with the contents of `src`.
///
/// Panics if that range does not lie within `buf`.
fn copy_into_buffer(buf: &mut [u8], src: &[u8], at: usize) {
    let end = at + src.len();
    let dst = &mut buf[at..end];
    dst.copy_from_slice(src);
}
492
/// Access to the initialized ("filled") bytes of a deframer buffer.
trait FilledDeframerBuffer {
    /// The filled bytes.
    fn filled(&self) -> &[u8];

    /// The filled bytes, mutably.
    fn filled_mut(&mut self) -> &mut [u8];

    /// Number of filled bytes.
    fn len(&self) -> usize {
        let filled = self.filled();
        filled.len()
    }

    /// Indexes into the filled bytes.
    ///
    /// Panics if `index` is out of range.
    fn filled_get<I>(&self, index: I) -> &I::Output
    where
        I: SliceIndex<[u8]>,
    {
        let filled = self.filled();
        filled.get(index).unwrap()
    }
}
509
/// Outcome of appending one payload to the handshake message currently being joined.
enum HandshakePayloadState {
    /// No complete message yet and no further buffered data to examine: more input has to
    /// arrive first.
    Blocked,
    /// A complete handshake message is available; the value is its length in bytes,
    /// including the handshake header.
    Complete(usize),
    /// No complete message yet, but more buffered records remain to be processed.
    Continue,
}
518
519struct HandshakePayloadMeta {
520    /// The range of bytes from the deframer buffer that contains data processed so far.
521    ///
522    /// This will need to be discarded as the last of the handshake message is `pop()`ped.
523    message: Range<usize>,
524    /// The range of bytes from the deframer buffer that contains payload.
525    payload: Range<usize>,
526    /// The protocol version as found in the decrypted handshake message.
527    version: ProtocolVersion,
528    /// The expected size of the handshake payload, if available.
529    ///
530    /// If the received payload exceeds 4 bytes (the handshake payload header), we update
531    /// `expected_len` to contain the payload length as advertised (at most 16_777_215 bytes).
532    expected_len: Option<usize>,
533    /// True if this is a QUIC handshake message.
534    ///
535    /// In the case of QUIC, we get a plaintext handshake data directly from the CRYPTO stream,
536    /// so there's no need to unwrap and decrypt the outer TLS record. This is implemented
537    /// by directly calling `MessageDeframer::push()` from the connection.
538    quic: bool,
539}
540
541/// Determine the expected length of the payload as advertised in the header.
542///
543/// Returns `Err` if the advertised length is larger than what we want to accept
544/// (`MAX_HANDSHAKE_SIZE`), `Ok(None)` if the buffer is too small to contain a complete header,
545/// and `Ok(Some(len))` otherwise.
546fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
547    if buf.len() < HEADER_SIZE {
548        return Ok(None);
549    }
550
551    let (header, _) = buf.split_at(HEADER_SIZE);
552    match codec::u24::read_bytes(&header[1..]) {
553        Ok(len) if len.0 > MAX_HANDSHAKE_SIZE => Err(Error::InvalidMessage(
554            InvalidMessage::HandshakePayloadTooLarge,
555        )),
556        Ok(len) => Ok(Some(HEADER_SIZE + usize::from(len))),
557        _ => Ok(None),
558    }
559}
560
561#[derive(Debug)]
562pub struct Deframed {
563    pub(crate) want_close_before_decrypt: bool,
564    pub(crate) aligned: bool,
565    pub(crate) trial_decryption_finished: bool,
566    pub message: PlainMessage,
567}
568
/// Size of a handshake message header: 1 byte of message type followed by a 3-byte (u24)
/// body length.
const HEADER_SIZE: usize = 1 + 3;

/// Largest handshake message body we are willing to accept.
///
/// TLS allows handshake messages of up to 16MB; we cap them at 64KB to limit the potential
/// for denial-of-service.
const MAX_HANDSHAKE_SIZE: u32 = 0xFFFF;

/// How many bytes of spare room `DeframerVecBuffer::prepare_read()` aims for before a read.
const READ_SIZE: usize = 4 * 1024;
577
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;
    use std::vec;

    use crate::msgs::message::Message;

    use super::*;

    #[test]
    fn check_incremental() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        feed_one_byte_at_a_time(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn check_incremental_2() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        feed_one_byte_at_a_time(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        feed_one_byte_at_a_time(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        assert!(d.has_pending());
        expect_alert(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn check_whole() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        expect_read_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn check_whole_2() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        expect_read_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
        expect_read_len(SECOND_MESSAGE.len(), d.input_bytes(SECOND_MESSAGE));

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        expect_alert(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        expect_read_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            d.input_bytes_concat(FIRST_MESSAGE, SECOND_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        expect_alert(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        expect_read_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            d.input_bytes_concat(SECOND_MESSAGE, FIRST_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        expect_alert(&mut d, &mut rl);
        expect_handshake(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = BufferedDeframer::default();
        let (head, tail) = FIRST_MESSAGE.split_at(3);
        expect_read_len(head.len(), d.input_bytes(head));
        feed_error(&mut d);
        expect_read_len(tail.len(), d.input_bytes(tail));

        let mut rl = RecordLayer::new();
        expect_handshake(&mut d, &mut rl);
        assert_fully_consumed(&d);
    }

    #[test]
    fn test_invalid_contenttype_errors() {
        expect_pop_error(
            INVALID_CONTENTTYPE_MESSAGE,
            InvalidMessage::InvalidContentType,
        );
    }

    #[test]
    fn test_invalid_version_errors() {
        expect_pop_error(
            INVALID_VERSION_MESSAGE,
            InvalidMessage::UnknownProtocolVersion,
        );
    }

    #[test]
    fn test_invalid_length_errors() {
        expect_pop_error(INVALID_LENGTH_MESSAGE, InvalidMessage::MessageTooLarge);
    }

    #[test]
    fn test_empty_applicationdata() {
        let mut d = BufferedDeframer::default();
        expect_read_len(
            EMPTY_APPLICATIONDATA_MESSAGE.len(),
            d.input_bytes(EMPTY_APPLICATIONDATA_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        let m = pop_message(&mut d, &mut rl);
        assert_eq!(m.typ, ContentType::ApplicationData);
        assert!(m.payload.0.is_empty());
        assert_fully_consumed(&d);
    }

    #[test]
    fn test_invalid_empty_errors() {
        let (mut d, mut rl) =
            expect_pop_error(INVALID_EMPTY_MESSAGE, InvalidMessage::InvalidEmptyPayload);
        // The error is fused: popping again reports it once more.
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
        );
    }

    #[test]
    fn test_limited_buffer() {
        const PAYLOAD_LEN: usize = 16_384;
        let mut message: Vec<u8> = Vec::with_capacity(5 + PAYLOAD_LEN);
        message.push(0x17); // ApplicationData
        message.extend_from_slice(&[0x03, 0x04]); // ProtocolVersion
        message.extend_from_slice(&(PAYLOAD_LEN as u16).to_be_bytes()); // payload length
        message.resize(5 + PAYLOAD_LEN, 0); // payload

        let mut d = BufferedDeframer::default();
        for _ in 0..4 {
            expect_read_len(4096, d.input_bytes(&message));
        }
        expect_read_len(
            OpaqueMessage::MAX_WIRE_SIZE - 16_384,
            d.input_bytes(&message),
        );
        assert!(d.input_bytes(&message).is_err());
    }

    /// Feeds a reader that fails, and checks the failure reaches the caller.
    fn feed_error(d: &mut BufferedDeframer) {
        let mut rd = ErrorRead::new(io::Error::from(io::ErrorKind::TimedOut));
        d.read(&mut rd)
            .expect_err("error not propagated");
    }

    /// Feeds `bytes` one byte per read, checking that every byte is buffered.
    fn feed_one_byte_at_a_time(d: &mut BufferedDeframer, bytes: &[u8]) {
        let before = d.buffer.len();

        for single in bytes.chunks(1) {
            expect_read_len(1, d.input_bytes(single));
            assert!(d.has_pending());
        }

        assert_eq!(before + bytes.len(), d.buffer.len());
    }

    /// Pops the next message, which must exist.
    fn pop_message(d: &mut BufferedDeframer, rl: &mut RecordLayer) -> PlainMessage {
        d.pop(rl, None)
            .unwrap()
            .unwrap()
            .message
    }

    /// Pops the next message and checks it is a well-formed handshake message.
    fn expect_handshake(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
        let m = pop_message(d, rl);
        assert_eq!(m.typ, ContentType::Handshake);
        Message::try_from(m).unwrap();
    }

    /// Pops the next message and checks it is a well-formed alert.
    fn expect_alert(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
        let m = pop_message(d, rl);
        assert_eq!(m.typ, ContentType::Alert);
        Message::try_from(m).unwrap();
    }

    /// Checks that nothing is left buffered and no error has been recorded.
    fn assert_fully_consumed(d: &BufferedDeframer) {
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    /// Reads `bytes` into a fresh deframer and checks the first pop fails with `expected`.
    fn expect_pop_error(
        bytes: &[u8],
        expected: InvalidMessage,
    ) -> (BufferedDeframer, RecordLayer) {
        let mut d = BufferedDeframer::default();
        expect_read_len(bytes.len(), d.input_bytes(bytes));

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(expected)
        );
        (d, rl)
    }

    // A deframer bundled with its own buffer, to keep the tests short.
    #[derive(Default)]
    struct BufferedDeframer {
        inner: MessageDeframer,
        buffer: DeframerVecBuffer,
    }

    impl BufferedDeframer {
        fn input_bytes(&mut self, bytes: &[u8]) -> io::Result<usize> {
            self.read(&mut io::Cursor::new(bytes))
        }

        fn input_bytes_concat(&mut self, bytes1: &[u8], bytes2: &[u8]) -> io::Result<usize> {
            let mut joined = vec![0u8; bytes1.len() + bytes2.len()];
            let (front, back) = joined.split_at_mut(bytes1.len());
            front.copy_from_slice(bytes1);
            back.copy_from_slice(bytes2);
            self.read(&mut io::Cursor::new(joined.as_slice()))
        }

        fn pop(
            &mut self,
            record_layer: &mut RecordLayer,
            negotiated_version: Option<ProtocolVersion>,
        ) -> Result<Option<Deframed>, Error> {
            let mut view = self.buffer.borrow();
            let result = self
                .inner
                .pop(record_layer, negotiated_version, &mut view);
            let consumed = view.pending_discard();
            self.buffer.discard(consumed);
            result
        }

        fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
            self.inner.read(rd, &mut self.buffer)
        }

        fn has_pending(&self) -> bool {
            self.buffer.has_pending()
        }
    }

    // Lets tests reach `MessageDeframer.last_error` directly.
    impl core::ops::Deref for BufferedDeframer {
        type Target = MessageDeframer;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }

    // A reader that scribbles over the destination, then fails once with a stored error.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            Self { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            buf.iter_mut()
                .enumerate()
                .for_each(|(i, b)| *b = i as u8);

            Err(self.error.take().unwrap())
        }
    }

    fn expect_read_len(want: usize, got: io::Result<usize>) {
        assert_eq!(got.ok(), Some(want));
    }

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-empty-applicationdata.bin");

    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
    const INVALID_VERSION_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-version.bin");
    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");
}