rustls/client/
handy.rs

1use crate::client;
2use crate::enums::SignatureScheme;
3use crate::error::Error;
4use crate::limited_cache;
5use crate::msgs::handshake::CertificateChain;
6use crate::msgs::persist;
7use crate::sign;
8use crate::NamedGroup;
9
10use pki_types::ServerName;
11
12use alloc::collections::VecDeque;
13use alloc::sync::Arc;
14use core::fmt;
15use std::sync::Mutex;
16
17/// An implementer of `ClientSessionStore` which does nothing.
18#[derive(Debug)]
19pub(super) struct NoClientSessionStorage;
20
21impl client::ClientSessionStore for NoClientSessionStorage {
22    fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
23
24    fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
25        None
26    }
27
28    fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
29
30    fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
31        None
32    }
33
34    fn remove_tls12_session(&self, _: &ServerName<'_>) {}
35
36    fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
37
38    fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
39        None
40    }
41}
42
/// Maximum number of TLS1.3 tickets the in-memory cache is meant to retain for one server.
const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
44
45struct ServerData {
46    kx_hint: Option<NamedGroup>,
47
48    // Zero or one TLS1.2 sessions.
49    #[cfg(feature = "tls12")]
50    tls12: Option<persist::Tls12ClientSessionValue>,
51
52    // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
53    tls13: VecDeque<persist::Tls13ClientSessionValue>,
54}
55
56impl Default for ServerData {
57    fn default() -> Self {
58        Self {
59            kx_hint: None,
60            #[cfg(feature = "tls12")]
61            tls12: None,
62            tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
63        }
64    }
65}
66
67/// An implementer of `ClientSessionStore` that stores everything
68/// in memory.
69///
70/// It enforces a limit on the number of entries to bound memory usage.
71pub struct ClientSessionMemoryCache {
72    servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
73}
74
75impl ClientSessionMemoryCache {
76    /// Make a new ClientSessionMemoryCache.  `size` is the
77    /// maximum number of stored sessions.
78    pub fn new(size: usize) -> Self {
79        let max_servers =
80            size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER;
81        Self {
82            servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
83        }
84    }
85}
86
87impl client::ClientSessionStore for ClientSessionMemoryCache {
88    fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
89        self.servers
90            .lock()
91            .unwrap()
92            .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
93    }
94
95    fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
96        self.servers
97            .lock()
98            .unwrap()
99            .get(server_name)
100            .and_then(|sd| sd.kx_hint)
101    }
102
103    fn set_tls12_session(
104        &self,
105        _server_name: ServerName<'static>,
106        _value: persist::Tls12ClientSessionValue,
107    ) {
108        #[cfg(feature = "tls12")]
109        self.servers
110            .lock()
111            .unwrap()
112            .get_or_insert_default_and_edit(_server_name.clone(), |data| data.tls12 = Some(_value));
113    }
114
115    fn tls12_session(
116        &self,
117        _server_name: &ServerName<'_>,
118    ) -> Option<persist::Tls12ClientSessionValue> {
119        #[cfg(not(feature = "tls12"))]
120        return None;
121
122        #[cfg(feature = "tls12")]
123        self.servers
124            .lock()
125            .unwrap()
126            .get(_server_name)
127            .and_then(|sd| sd.tls12.as_ref().cloned())
128    }
129
130    fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
131        #[cfg(feature = "tls12")]
132        self.servers
133            .lock()
134            .unwrap()
135            .get_mut(_server_name)
136            .and_then(|data| data.tls12.take());
137    }
138
139    fn insert_tls13_ticket(
140        &self,
141        server_name: ServerName<'static>,
142        value: persist::Tls13ClientSessionValue,
143    ) {
144        self.servers
145            .lock()
146            .unwrap()
147            .get_or_insert_default_and_edit(server_name.clone(), |data| {
148                if data.tls13.len() == data.tls13.capacity() {
149                    data.tls13.pop_front();
150                }
151                data.tls13.push_back(value);
152            });
153    }
154
155    fn take_tls13_ticket(
156        &self,
157        server_name: &ServerName<'static>,
158    ) -> Option<persist::Tls13ClientSessionValue> {
159        self.servers
160            .lock()
161            .unwrap()
162            .get_mut(server_name)
163            .and_then(|data| data.tls13.pop_back())
164    }
165}
166
167impl fmt::Debug for ClientSessionMemoryCache {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        // Note: we omit self.servers as it may contain sensitive data.
170        f.debug_struct("ClientSessionMemoryCache")
171            .finish()
172    }
173}
174
175#[derive(Debug)]
176pub(super) struct FailResolveClientCert {}
177
178impl client::ResolvesClientCert for FailResolveClientCert {
179    fn resolve(
180        &self,
181        _root_hint_subjects: &[&[u8]],
182        _sigschemes: &[SignatureScheme],
183    ) -> Option<Arc<sign::CertifiedKey>> {
184        None
185    }
186
187    fn has_certs(&self) -> bool {
188        false
189    }
190}
191
192#[derive(Debug)]
193pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
194
195impl AlwaysResolvesClientCert {
196    pub(super) fn new(
197        private_key: Arc<dyn sign::SigningKey>,
198        chain: CertificateChain,
199    ) -> Result<Self, Error> {
200        Ok(Self(Arc::new(sign::CertifiedKey::new(
201            chain.0,
202            private_key,
203        ))))
204    }
205}
206
207impl client::ResolvesClientCert for AlwaysResolvesClientCert {
208    fn resolve(
209        &self,
210        _root_hint_subjects: &[&[u8]],
211        _sigschemes: &[SignatureScheme],
212    ) -> Option<Arc<sign::CertifiedKey>> {
213        Some(Arc::clone(&self.0))
214    }
215
216    fn has_certs(&self) -> bool {
217        true
218    }
219}
220
#[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
mod tests {
    use super::NoClientSessionStorage;
    use crate::client::ClientSessionStore;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;
    use crate::test_provider::cipher_suite;
    use alloc::vec::Vec;

    use pki_types::{ServerName, UnixTime};

    /// Values handed to `NoClientSessionStorage` setters must never come back from its getters.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let store = NoClientSessionStorage {};
        let server = ServerName::try_from("example.com").unwrap();
        let timestamp = UnixTime::now();

        // Key-exchange hints are dropped.
        store.set_kx_hint(server.clone(), NamedGroup::X25519);
        assert!(store.kx_hint(&server).is_none());

        // TLS1.2 sessions are dropped, and removing one is harmless.
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;

            let SupportedCipherSuite::Tls12(suite12) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            let session = Tls12ClientSessionValue::new(
                suite12,
                SessionId::empty(),
                Vec::new(),
                &[],
                CertificateChain::default(),
                timestamp,
                0,
                true,
            );
            store.set_tls12_session(server.clone(), session);
            assert!(store.tls12_session(&server).is_none());
            store.remove_tls12_session(&server);
        }

        // TLS1.3 tickets are dropped. Without the `tls12` feature the suite enum has a single
        // variant, so this `match` is irrefutable and the catch-all arm must be compiled out.
        #[cfg_attr(not(feature = "tls12"), allow(clippy::infallible_destructuring_match))]
        let suite13 = match cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };

        let ticket = Tls13ClientSessionValue::new(
            suite13,
            Vec::new(),
            &[],
            CertificateChain::default(),
            timestamp,
            0,
            0,
            0,
        );
        store.insert_tls13_ticket(server.clone(), ticket);
        assert!(store.take_tls13_ticket(&server).is_none());
    }
}