rustls/msgs/
persist.rs

1use crate::enums::{CipherSuite, ProtocolVersion};
2use crate::error::InvalidMessage;
3use crate::msgs::base::{PayloadU16, PayloadU8};
4use crate::msgs::codec::{Codec, Reader};
5use crate::msgs::handshake::CertificateChain;
6#[cfg(feature = "tls12")]
7use crate::msgs::handshake::SessionId;
8#[cfg(feature = "tls12")]
9use crate::tls12::Tls12CipherSuite;
10use crate::tls13::Tls13CipherSuite;
11
12use pki_types::{DnsName, UnixTime};
13use zeroize::Zeroizing;
14
15use alloc::vec::Vec;
16use core::cmp;
17#[cfg(feature = "tls12")]
18use core::mem;
19
20pub(crate) struct Retrieved<T> {
21    pub(crate) value: T,
22    retrieved_at: UnixTime,
23}
24
25impl<T> Retrieved<T> {
26    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
27        Self {
28            value,
29            retrieved_at,
30        }
31    }
32
33    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
34        Some(Retrieved {
35            value: f(&self.value)?,
36            retrieved_at: self.retrieved_at,
37        })
38    }
39}
40
41impl Retrieved<&Tls13ClientSessionValue> {
42    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
43        let age_secs = self
44            .retrieved_at
45            .as_secs()
46            .saturating_sub(self.value.common.epoch);
47        let age_millis = age_secs as u32 * 1000;
48        age_millis.wrapping_add(self.value.age_add)
49    }
50}
51
52impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
53    pub(crate) fn has_expired(&self) -> bool {
54        let common = &*self.value;
55        common.lifetime_secs != 0
56            && common
57                .epoch
58                .saturating_add(u64::from(common.lifetime_secs))
59                < self.retrieved_at.as_secs()
60    }
61}
62
63impl<T> core::ops::Deref for Retrieved<T> {
64    type Target = T;
65
66    fn deref(&self) -> &Self::Target {
67        &self.value
68    }
69}
70
71#[derive(Debug)]
72pub struct Tls13ClientSessionValue {
73    suite: &'static Tls13CipherSuite,
74    age_add: u32,
75    max_early_data_size: u32,
76    pub(crate) common: ClientSessionCommon,
77    quic_params: PayloadU16,
78}
79
80impl Tls13ClientSessionValue {
81    pub(crate) fn new(
82        suite: &'static Tls13CipherSuite,
83        ticket: Vec<u8>,
84        secret: &[u8],
85        server_cert_chain: CertificateChain,
86        time_now: UnixTime,
87        lifetime_secs: u32,
88        age_add: u32,
89        max_early_data_size: u32,
90    ) -> Self {
91        Self {
92            suite,
93            age_add,
94            max_early_data_size,
95            common: ClientSessionCommon::new(
96                ticket,
97                secret,
98                time_now,
99                lifetime_secs,
100                server_cert_chain,
101            ),
102            quic_params: PayloadU16(Vec::new()),
103        }
104    }
105
106    pub fn max_early_data_size(&self) -> u32 {
107        self.max_early_data_size
108    }
109
110    pub fn suite(&self) -> &'static Tls13CipherSuite {
111        self.suite
112    }
113
114    #[doc(hidden)]
115    /// Test only: rewind epoch by `delta` seconds.
116    pub fn rewind_epoch(&mut self, delta: u32) {
117        self.common.epoch -= delta as u64;
118    }
119
120    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
121        self.quic_params = PayloadU16(quic_params.to_vec());
122    }
123
124    pub fn quic_params(&self) -> Vec<u8> {
125        self.quic_params.0.clone()
126    }
127}
128
129impl core::ops::Deref for Tls13ClientSessionValue {
130    type Target = ClientSessionCommon;
131
132    fn deref(&self) -> &Self::Target {
133        &self.common
134    }
135}
136
/// A TLS 1.2 client resumption record (session ID and/or ticket).
///
/// Every field is gated on the `tls12` feature, so without it this is an
/// empty type.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}

#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Records a TLS 1.2 session, timestamped at `time_now`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Vec<u8>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        let common = ClientSessionCommon::new(
            ticket,
            master_secret,
            time_now,
            lifetime_secs,
            server_cert_chain,
        );
        Self {
            suite,
            session_id,
            extended_ms,
            common,
        }
    }

    /// Moves the ticket bytes out, leaving an empty ticket behind.
    pub(crate) fn take_ticket(&mut self) -> Vec<u8> {
        let ticket = &mut self.common.ticket.0;
        mem::take(ticket)
    }

    /// Whether the extended master secret was used for this session.
    pub(crate) fn extended_ms(&self) -> bool {
        let Self { extended_ms, .. } = self;
        *extended_ms
    }

    /// The cipher suite stored with this session.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        let Self { suite, .. } = self;
        suite
    }

    #[doc(hidden)]
    /// Test only: rewind epoch by `delta` seconds.
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch -= u64::from(delta);
    }
}

#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
203
204#[derive(Debug, Clone)]
205pub struct ClientSessionCommon {
206    ticket: PayloadU16,
207    secret: Zeroizing<PayloadU8>,
208    epoch: u64,
209    lifetime_secs: u32,
210    server_cert_chain: CertificateChain,
211}
212
213impl ClientSessionCommon {
214    fn new(
215        ticket: Vec<u8>,
216        secret: &[u8],
217        time_now: UnixTime,
218        lifetime_secs: u32,
219        server_cert_chain: CertificateChain,
220    ) -> Self {
221        Self {
222            ticket: PayloadU16(ticket),
223            secret: Zeroizing::new(PayloadU8(secret.to_vec())),
224            epoch: time_now.as_secs(),
225            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
226            server_cert_chain,
227        }
228    }
229
230    pub(crate) fn server_cert_chain(&self) -> &CertificateChain {
231        &self.server_cert_chain
232    }
233
234    pub(crate) fn secret(&self) -> &[u8] {
235        self.secret.0.as_ref()
236    }
237
238    pub(crate) fn ticket(&self) -> &[u8] {
239        self.ticket.0.as_ref()
240    }
241}
242
/// Upper bound, in seconds, on a client session's stored `lifetime_secs`: seven
/// days, matching the RFC 8446 section 4.6.1 limit on `ticket_lifetime`.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;
244
/// Maximum tolerated difference, in milliseconds, between the ticket age the
/// client reports and the age the server computes, over the maximum ticket
/// lifetime period. It covers TCP retransmission delays (packet loss while the
/// client sends the ClientHello or receives the NewSessionTicket) _and_ real
/// clock skew accumulated over that period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
250
251// --- Server types ---
252#[derive(Debug)]
253pub struct ServerSessionValue {
254    pub(crate) sni: Option<DnsName<'static>>,
255    pub(crate) version: ProtocolVersion,
256    pub(crate) cipher_suite: CipherSuite,
257    pub(crate) master_secret: Zeroizing<PayloadU8>,
258    pub(crate) extended_ms: bool,
259    pub(crate) client_cert_chain: Option<CertificateChain>,
260    pub(crate) alpn: Option<PayloadU8>,
261    pub(crate) application_data: PayloadU16,
262    pub creation_time_sec: u64,
263    pub(crate) age_obfuscation_offset: u32,
264    freshness: Option<bool>,
265}
266
267impl Codec for ServerSessionValue {
268    fn encode(&self, bytes: &mut Vec<u8>) {
269        if let Some(ref sni) = self.sni {
270            1u8.encode(bytes);
271            let sni_bytes: &str = sni.as_ref();
272            PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
273        } else {
274            0u8.encode(bytes);
275        }
276        self.version.encode(bytes);
277        self.cipher_suite.encode(bytes);
278        self.master_secret.encode(bytes);
279        (u8::from(self.extended_ms)).encode(bytes);
280        if let Some(ref chain) = self.client_cert_chain {
281            1u8.encode(bytes);
282            chain.encode(bytes);
283        } else {
284            0u8.encode(bytes);
285        }
286        if let Some(ref alpn) = self.alpn {
287            1u8.encode(bytes);
288            alpn.encode(bytes);
289        } else {
290            0u8.encode(bytes);
291        }
292        self.application_data.encode(bytes);
293        self.creation_time_sec.encode(bytes);
294        self.age_obfuscation_offset
295            .encode(bytes);
296    }
297
298    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
299        let has_sni = u8::read(r)?;
300        let sni = if has_sni == 1 {
301            let dns_name = PayloadU8::read(r)?;
302            let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
303                Ok(dns_name) => dns_name.to_owned(),
304                Err(_) => return Err(InvalidMessage::InvalidServerName),
305            };
306
307            Some(dns_name)
308        } else {
309            None
310        };
311
312        let v = ProtocolVersion::read(r)?;
313        let cs = CipherSuite::read(r)?;
314        let ms = Zeroizing::new(PayloadU8::read(r)?);
315        let ems = u8::read(r)?;
316        let has_ccert = u8::read(r)? == 1;
317        let ccert = if has_ccert {
318            Some(CertificateChain::read(r)?)
319        } else {
320            None
321        };
322        let has_alpn = u8::read(r)? == 1;
323        let alpn = if has_alpn {
324            Some(PayloadU8::read(r)?)
325        } else {
326            None
327        };
328        let application_data = PayloadU16::read(r)?;
329        let creation_time_sec = u64::read(r)?;
330        let age_obfuscation_offset = u32::read(r)?;
331
332        Ok(Self {
333            sni,
334            version: v,
335            cipher_suite: cs,
336            master_secret: ms,
337            extended_ms: ems == 1u8,
338            client_cert_chain: ccert,
339            alpn,
340            application_data,
341            creation_time_sec,
342            age_obfuscation_offset,
343            freshness: None,
344        })
345    }
346}
347
348impl ServerSessionValue {
349    pub(crate) fn new(
350        sni: Option<&DnsName<'_>>,
351        v: ProtocolVersion,
352        cs: CipherSuite,
353        ms: &[u8],
354        client_cert_chain: Option<CertificateChain>,
355        alpn: Option<Vec<u8>>,
356        application_data: Vec<u8>,
357        creation_time: UnixTime,
358        age_obfuscation_offset: u32,
359    ) -> Self {
360        Self {
361            sni: sni.map(|dns| dns.to_owned()),
362            version: v,
363            cipher_suite: cs,
364            master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
365            extended_ms: false,
366            client_cert_chain,
367            alpn: alpn.map(PayloadU8::new),
368            application_data: PayloadU16::new(application_data),
369            creation_time_sec: creation_time.as_secs(),
370            age_obfuscation_offset,
371            freshness: None,
372        }
373    }
374
375    #[cfg(feature = "tls12")]
376    pub(crate) fn set_extended_ms_used(&mut self) {
377        self.extended_ms = true;
378    }
379
380    pub(crate) fn set_freshness(
381        mut self,
382        obfuscated_client_age_ms: u32,
383        time_now: UnixTime,
384    ) -> Self {
385        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
386        let server_age_ms = (time_now
387            .as_secs()
388            .saturating_sub(self.creation_time_sec) as u32)
389            .saturating_mul(1000);
390
391        let age_difference = if client_age_ms < server_age_ms {
392            server_age_ms - client_age_ms
393        } else {
394            client_age_ms - server_age_ms
395        };
396
397        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
398        self
399    }
400
401    pub(crate) fn is_fresh(&self) -> bool {
402        self.freshness.unwrap_or_default()
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::enums::*;

    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        // No SNI, TLS 1.2, suite 0xc023, 3-byte master secret, no EMS, no client
        // certs, no ALPN, empty application data, creation time, obfuscation offset.
        let bytes = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.sni.is_none());
        assert!(ssv.client_cert_chain.is_none());
        assert_eq!(ssv.creation_time_sec, 0x1223_3445_5667_7889);
        assert_eq!(ssv.age_obfuscation_offset, 0xfeed_f00d);
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        // Same layout as the no-SNI case, but with a client certificate chain:
        // presence flag 0x01, u24 list length 4, then one certificate with a
        // u24 length of 1 and a single body byte.
        let bytes = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x01, 0x00, 0x00, 0x04,
            0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78,
            0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.sni.is_none());
        assert!(ssv.client_cert_chain.is_some());
        assert_eq!(ssv.creation_time_sec, 0x1223_3445_5667_7889);
        assert_eq!(ssv.age_obfuscation_offset, 0xfeed_f00d);
        assert_eq!(ssv.get_encoding(), bytes);
    }
}