// rustls/msgs/fragmenter.rs

1use crate::enums::ContentType;
2use crate::enums::ProtocolVersion;
3use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
4use crate::Error;
/// Largest plaintext payload a single TLS record may carry (2^14 bytes).
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Size of the record header that precedes every fragment on the wire:
/// content type (1 byte), protocol version (2 bytes), payload length (2 bytes).
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
/// Largest encoded record: the header plus a maximum-length payload.
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
8
/// Splits plaintext messages into TLS-record-sized fragments.
#[derive(Debug)]
pub struct MessageFragmenter {
    /// Maximum number of payload bytes per fragment (record header excluded).
    max_frag: usize,
}
12
13impl Default for MessageFragmenter {
14    fn default() -> Self {
15        Self {
16            max_frag: MAX_FRAGMENT_LEN,
17        }
18    }
19}
20
21impl MessageFragmenter {
22    /// Take the Message `msg` and re-fragment it into new
23    /// messages whose fragment is no more than max_frag.
24    /// Return an iterator across those messages.
25    /// Payloads are borrowed.
26    pub fn fragment_message<'a>(
27        &self,
28        msg: &'a PlainMessage,
29    ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
30        self.fragment_slice(msg.typ, msg.version, &msg.payload.0)
31    }
32
33    /// Enqueue borrowed fragments of (version, typ, payload) which
34    /// are no longer than max_frag onto the `out` deque.
35    pub(crate) fn fragment_slice<'a>(
36        &self,
37        typ: ContentType,
38        version: ProtocolVersion,
39        payload: &'a [u8],
40    ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
41        payload
42            .chunks(self.max_frag)
43            .map(move |c| BorrowedPlainMessage {
44                typ,
45                version,
46                payload: c,
47            })
48    }
49
50    /// Set the maximum fragment size that will be produced.
51    ///
52    /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments
53    /// up to 10 bytes long.
54    ///
55    /// A `max_fragment_size` of `None` sets the highest allowable fragment size.
56    ///
57    /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389.
58    pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> {
59        self.max_frag = match max_fragment_size {
60            Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
61            None => MAX_FRAGMENT_LEN,
62            _ => return Err(Error::BadMaxFragmentSize),
63        };
64        Ok(())
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, MAX_FRAGMENT_LEN, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::ContentType;
    use crate::enums::ProtocolVersion;
    use crate::msgs::base::Payload;
    use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
    use crate::Error;
    use std::prelude::v1::*;

    /// Assert that `m` has the given content type, version and payload, and
    /// that its unencrypted encoding is exactly `total_len` bytes long.
    fn msg_eq(
        m: &BorrowedPlainMessage,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload, bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    /// Build a TLS 1.2 handshake message carrying `len` bytes of filler payload.
    fn handshake_msg(len: usize) -> PlainMessage {
        PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new((0..len).map(|_| 0x42u8).collect::<Vec<u8>>()),
        }
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn max_fragment_size_bounds() {
        let mut frag = MessageFragmenter::default();
        assert!(matches!(
            frag.set_max_fragment_size(Some(31)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(frag.set_max_fragment_size(Some(32)).is_ok());
        assert!(frag
            .set_max_fragment_size(Some(MAX_FRAGMENT_SIZE))
            .is_ok());
        assert!(matches!(
            frag.set_max_fragment_size(Some(MAX_FRAGMENT_SIZE + 1)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(frag.set_max_fragment_size(None).is_ok());
    }

    #[test]
    fn rejected_size_keeps_previous_limit() {
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        assert!(frag.set_max_fragment_size(Some(0)).is_err());

        // The 27-byte payload limit from the accepted call must still apply:
        // 54 bytes split into exactly two full fragments.
        let m = handshake_msg(54);
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 2);
        assert!(q
            .iter()
            .all(|f| f.payload.len() == 32 - PACKET_OVERHEAD));
    }

    #[test]
    fn default_allows_full_record() {
        let frag = MessageFragmenter::default();

        let full = handshake_msg(MAX_FRAGMENT_LEN);
        assert_eq!(frag.fragment_message(&full).count(), 1);

        let over = handshake_msg(MAX_FRAGMENT_LEN + 1);
        let q = frag
            .fragment_message(&over)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 2);
        assert_eq!(q[0].payload.len(), MAX_FRAGMENT_LEN);
        assert_eq!(q[1].payload.len(), 1);
    }

    #[test]
    fn empty_payload_yields_no_fragments() {
        let frag = MessageFragmenter::default();
        let m = handshake_msg(0);
        assert_eq!(frag.fragment_message(&m).count(), 0);
    }
}