rustls/
record_layer.rs

1use crate::crypto::cipher::{MessageDecrypter, MessageEncrypter};
2use crate::error::Error;
3use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
4
5#[cfg(feature = "logging")]
6use crate::log::trace;
7
8use alloc::boxed::Box;
9
/// When the read or write sequence number equals this value, the record layer
/// reports that the connection should be closed (see `want_close_before_decrypt`
/// and `wants_close_before_encrypt`), well before the hard limit is reached.
const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;
/// Once the write sequence number reaches this value the encryption key is
/// considered exhausted and no further messages may be encrypted with it.
const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64;
12
/// Lifecycle of the keying material for one direction (encryption or decryption)
/// of the record layer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DirectionState {
    /// No keying material.
    Invalid,

    /// Keying material present, but not yet in use.
    Prepared,

    /// Keying material in use.
    Active,
}
24
/// Record layer that tracks decryption and encryption keys.
///
/// Each direction has its own keying material, sequence number and
/// [`DirectionState`]; keys are installed with the `prepare_*` / `start_*`
/// (or `set_*`) methods.
pub struct RecordLayer {
    /// Encrypter for outgoing messages; `<dyn MessageEncrypter>::invalid()` until keys are set.
    message_encrypter: Box<dyn MessageEncrypter>,
    /// Decrypter for incoming messages; `<dyn MessageDecrypter>::invalid()` until keys are set.
    message_decrypter: Box<dyn MessageDecrypter>,
    /// Sequence number for the next outgoing message; reset when a new encrypter is prepared.
    write_seq: u64,
    /// Sequence number for the next incoming message; reset when a new decrypter is prepared.
    read_seq: u64,
    /// Whether any message has ever been decrypted. Unlike `read_seq`, this is never reset.
    has_decrypted: bool,
    /// Whether the encrypter is absent, prepared, or in use.
    encrypt_state: DirectionState,
    /// Whether the decrypter is absent, prepared, or in use.
    decrypt_state: DirectionState,

    /// Messages encrypted with other keys may be encountered (when TLS1.3
    /// 0-RTT is attempted but rejected), so their decryption failures are
    /// swallowed rather than reported. This tracks how many more bytes of
    /// undecryptable payload may still be dropped; `None` means trial
    /// decryption is disabled.
    trial_decryption_len: Option<usize>,
}
40
41impl RecordLayer {
42    /// Create new record layer with no keys.
43    pub fn new() -> Self {
44        Self {
45            message_encrypter: <dyn MessageEncrypter>::invalid(),
46            message_decrypter: <dyn MessageDecrypter>::invalid(),
47            write_seq: 0,
48            read_seq: 0,
49            has_decrypted: false,
50            encrypt_state: DirectionState::Invalid,
51            decrypt_state: DirectionState::Invalid,
52            trial_decryption_len: None,
53        }
54    }
55
56    /// Decrypt a TLS message.
57    ///
58    /// `encr` is a decoded message allegedly received from the peer.
59    /// If it can be decrypted, its decryption is returned.  Otherwise,
60    /// an error is returned.
61    pub(crate) fn decrypt_incoming(
62        &mut self,
63        encr: OpaqueMessage,
64    ) -> Result<Option<Decrypted>, Error> {
65        if self.decrypt_state != DirectionState::Active {
66            return Ok(Some(Decrypted {
67                want_close_before_decrypt: false,
68                plaintext: encr.into_plain_message(),
69            }));
70        }
71
72        // Set to `true` if the peer appears to getting close to encrypting
73        // too many messages with this key.
74        //
75        // Perhaps if we send an alert well before their counter wraps, a
76        // buggy peer won't make a terrible mistake here?
77        //
78        // Note that there's no reason to refuse to decrypt: the security
79        // failure has already happened.
80        let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
81
82        let encrypted_len = encr.payload().len();
83        match self
84            .message_decrypter
85            .decrypt(encr, self.read_seq)
86        {
87            Ok(plaintext) => {
88                self.read_seq += 1;
89                if !self.has_decrypted {
90                    self.has_decrypted = true;
91                }
92                Ok(Some(Decrypted {
93                    want_close_before_decrypt,
94                    plaintext,
95                }))
96            }
97            Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
98                trace!("Dropping undecryptable message after aborted early_data");
99                Ok(None)
100            }
101            Err(err) => Err(err),
102        }
103    }
104
105    /// Encrypt a TLS message.
106    ///
107    /// `plain` is a TLS message we'd like to send.  This function
108    /// panics if the requisite keying material hasn't been established yet.
109    pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
110        debug_assert!(self.encrypt_state == DirectionState::Active);
111        assert!(!self.encrypt_exhausted());
112        let seq = self.write_seq;
113        self.write_seq += 1;
114        self.message_encrypter
115            .encrypt(plain, seq)
116            .unwrap()
117    }
118
119    /// Prepare to use the given `MessageEncrypter` for future message encryption.
120    /// It is not used until you call `start_encrypting`.
121    pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
122        self.message_encrypter = cipher;
123        self.write_seq = 0;
124        self.encrypt_state = DirectionState::Prepared;
125    }
126
127    /// Prepare to use the given `MessageDecrypter` for future message decryption.
128    /// It is not used until you call `start_decrypting`.
129    pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
130        self.message_decrypter = cipher;
131        self.read_seq = 0;
132        self.decrypt_state = DirectionState::Prepared;
133    }
134
135    /// Start using the `MessageEncrypter` previously provided to the previous
136    /// call to `prepare_message_encrypter`.
137    pub(crate) fn start_encrypting(&mut self) {
138        debug_assert!(self.encrypt_state == DirectionState::Prepared);
139        self.encrypt_state = DirectionState::Active;
140    }
141
142    /// Start using the `MessageDecrypter` previously provided to the previous
143    /// call to `prepare_message_decrypter`.
144    pub(crate) fn start_decrypting(&mut self) {
145        debug_assert!(self.decrypt_state == DirectionState::Prepared);
146        self.decrypt_state = DirectionState::Active;
147    }
148
149    /// Set and start using the given `MessageEncrypter` for future outgoing
150    /// message encryption.
151    pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
152        self.prepare_message_encrypter(cipher);
153        self.start_encrypting();
154    }
155
156    /// Set and start using the given `MessageDecrypter` for future incoming
157    /// message decryption.
158    pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
159        self.prepare_message_decrypter(cipher);
160        self.start_decrypting();
161        self.trial_decryption_len = None;
162    }
163
164    /// Set and start using the given `MessageDecrypter` for future incoming
165    /// message decryption, and enable "trial decryption" mode for when TLS1.3
166    /// 0-RTT is attempted but rejected by the server.
167    pub(crate) fn set_message_decrypter_with_trial_decryption(
168        &mut self,
169        cipher: Box<dyn MessageDecrypter>,
170        max_length: usize,
171    ) {
172        self.prepare_message_decrypter(cipher);
173        self.start_decrypting();
174        self.trial_decryption_len = Some(max_length);
175    }
176
177    pub(crate) fn finish_trial_decryption(&mut self) {
178        self.trial_decryption_len = None;
179    }
180
181    /// Return true if we are getting close to encrypting too many
182    /// messages with our encryption key.
183    pub(crate) fn wants_close_before_encrypt(&self) -> bool {
184        self.write_seq == SEQ_SOFT_LIMIT
185    }
186
187    /// Return true if we outright refuse to do anything with the
188    /// encryption key.
189    pub(crate) fn encrypt_exhausted(&self) -> bool {
190        self.write_seq >= SEQ_HARD_LIMIT
191    }
192
193    pub(crate) fn is_encrypting(&self) -> bool {
194        self.encrypt_state == DirectionState::Active
195    }
196
197    /// Return true if we have ever decrypted a message. This is used in place
198    /// of checking the read_seq since that will be reset on key updates.
199    pub(crate) fn has_decrypted(&self) -> bool {
200        self.has_decrypted
201    }
202
203    pub(crate) fn write_seq(&self) -> u64 {
204        self.write_seq
205    }
206
207    pub(crate) fn read_seq(&self) -> u64 {
208        self.read_seq
209    }
210
211    fn doing_trial_decryption(&mut self, requested: usize) -> bool {
212        match self
213            .trial_decryption_len
214            .and_then(|value| value.checked_sub(requested))
215        {
216            Some(remaining) => {
217                self.trial_decryption_len = Some(remaining);
218                true
219            }
220            _ => false,
221        }
222    }
223}
224
/// Result of decryption, as returned by `RecordLayer::decrypt_incoming` for
/// a message that was not dropped.
#[derive(Debug)]
pub(crate) struct Decrypted {
    /// Whether the peer appears to be getting close to encrypting too many messages with this key.
    /// Always `false` when decryption was not active.
    pub(crate) want_close_before_decrypt: bool,
    /// The decrypted message, or the message converted to plaintext unchanged
    /// if decryption was not active.
    pub(crate) plaintext: PlainMessage,
}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContentType, ProtocolVersion};
    use std::vec;

    /// Decrypter that returns each record's payload unchanged.
    struct PassThroughDecrypter;

    impl MessageDecrypter for PassThroughDecrypter {
        fn decrypt(&mut self, m: OpaqueMessage, _: u64) -> Result<PlainMessage, Error> {
            Ok(m.into_plain_message())
        }
    }

    /// Decrypter that always fails, modelling records protected with keys we don't hold.
    struct FailingDecrypter;

    impl MessageDecrypter for FailingDecrypter {
        fn decrypt(&mut self, _: OpaqueMessage, _: u64) -> Result<PlainMessage, Error> {
            Err(Error::DecryptError)
        }
    }

    /// Build a handshake record whose payload is `len` bytes long.
    fn handshake_record(len: usize) -> OpaqueMessage {
        OpaqueMessage::new(
            ContentType::Handshake,
            ProtocolVersion::TLSv1_2,
            vec![0xC0; len],
        )
    }

    #[test]
    fn test_has_decrypted() {
        // A record layer starts out invalid, having never decrypted.
        let mut record_layer = RecordLayer::new();
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Invalid
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Preparing the record layer should update the decrypt state, but shouldn't affect whether it
        // has decrypted.
        record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Prepared
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Starting decryption should update the decrypt state, but not affect whether it has decrypted.
        record_layer.start_decrypting();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Decrypting a message should update the read_seq and track that we have now performed
        // a decryption.
        record_layer
            .decrypt_incoming(handshake_record(3))
            .unwrap();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 1);
        assert!(record_layer.has_decrypted());

        // Resetting the record layer message decrypter (as if a key update occurred) should reset
        // the read_seq number, but not our knowledge of whether we have decrypted previously.
        record_layer.set_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(record_layer.has_decrypted());
    }

    #[test]
    fn test_trial_decryption() {
        let mut record_layer = RecordLayer::new();
        record_layer.set_message_decrypter_with_trial_decryption(Box::new(FailingDecrypter), 5);

        // A failure on a record that fits in the remaining budget is swallowed.
        assert!(record_layer
            .decrypt_incoming(handshake_record(3))
            .unwrap()
            .is_none());

        // Only 2 bytes of budget remain, so a 3-byte failure is reported.
        assert!(matches!(
            record_layer.decrypt_incoming(handshake_record(3)),
            Err(Error::DecryptError)
        ));

        // Failed decryptions never advance the read sequence number.
        assert_eq!(record_layer.read_seq(), 0);
        assert!(!record_layer.has_decrypted());

        // Once trial decryption is finished, every failure is reported.
        record_layer.set_message_decrypter_with_trial_decryption(Box::new(FailingDecrypter), 5);
        record_layer.finish_trial_decryption();
        assert!(matches!(
            record_layer.decrypt_incoming(handshake_record(1)),
            Err(Error::DecryptError)
        ));
    }
}