1use crate::error::Error;
2use crate::limited_cache;
3use crate::msgs::handshake::CertificateChain;
4use crate::server;
5use crate::server::ClientHello;
6use crate::sign;
7use crate::webpki::{verify_server_name, ParsedCertificate};
8
9use pki_types::{DnsName, ServerName};
10
11use alloc::string::{String, ToString};
12use alloc::sync::Arc;
13use alloc::vec::Vec;
14use core::fmt::{Debug, Formatter};
15use std::collections::HashMap;
16use std::sync::Mutex;
17
18#[derive(Debug)]
20pub struct NoServerSessionStorage {}
21
22impl server::StoresServerSessions for NoServerSessionStorage {
23 fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
24 false
25 }
26 fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
27 None
28 }
29 fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
30 None
31 }
32 fn can_cache(&self) -> bool {
33 false
34 }
35}
36
37pub struct ServerSessionMemoryCache {
41 cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
42}
43
44impl ServerSessionMemoryCache {
45 pub fn new(size: usize) -> Arc<Self> {
49 Arc::new(Self {
50 cache: Mutex::new(limited_cache::LimitedCache::new(size)),
51 })
52 }
53}
54
55impl server::StoresServerSessions for ServerSessionMemoryCache {
56 fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
57 self.cache
58 .lock()
59 .unwrap()
60 .insert(key, value);
61 true
62 }
63
64 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
65 self.cache
66 .lock()
67 .unwrap()
68 .get(key)
69 .cloned()
70 }
71
72 fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
73 self.cache.lock().unwrap().remove(key)
74 }
75
76 fn can_cache(&self) -> bool {
77 true
78 }
79}
80
81impl Debug for ServerSessionMemoryCache {
82 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
83 f.debug_struct("ServerSessionMemoryCache")
84 .finish()
85 }
86}
87
88#[derive(Debug)]
90pub(super) struct NeverProducesTickets {}
91
92impl server::ProducesTickets for NeverProducesTickets {
93 fn enabled(&self) -> bool {
94 false
95 }
96 fn lifetime(&self) -> u32 {
97 0
98 }
99 fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
100 None
101 }
102 fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
103 None
104 }
105}
106
107#[derive(Debug)]
109pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
110
111impl AlwaysResolvesChain {
112 pub(super) fn new(private_key: Arc<dyn sign::SigningKey>, chain: CertificateChain) -> Self {
114 Self(Arc::new(sign::CertifiedKey::new(chain.0, private_key)))
115 }
116
117 pub(super) fn new_with_extras(
121 private_key: Arc<dyn sign::SigningKey>,
122 chain: CertificateChain,
123 ocsp: Vec<u8>,
124 ) -> Self {
125 let mut r = Self::new(private_key, chain);
126
127 {
128 let cert = Arc::make_mut(&mut r.0);
129 if !ocsp.is_empty() {
130 cert.ocsp = Some(ocsp);
131 }
132 }
133
134 r
135 }
136}
137
138impl server::ResolvesServerCert for AlwaysResolvesChain {
139 fn resolve(&self, _client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
140 Some(Arc::clone(&self.0))
141 }
142}
143
144#[derive(Debug)]
147pub struct ResolvesServerCertUsingSni {
148 by_name: HashMap<String, Arc<sign::CertifiedKey>>,
149}
150
151impl ResolvesServerCertUsingSni {
152 pub fn new() -> Self {
154 Self {
155 by_name: HashMap::new(),
156 }
157 }
158
159 pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
165 let server_name = {
166 let checked_name = DnsName::try_from(name)
167 .map_err(|_| Error::General("Bad DNS name".into()))
168 .map(|name| name.to_lowercase_owned())?;
169 ServerName::DnsName(checked_name)
170 };
171
172 ck.end_entity_cert()
182 .and_then(ParsedCertificate::try_from)
183 .and_then(|cert| verify_server_name(&cert, &server_name))?;
184
185 if let ServerName::DnsName(name) = server_name {
186 self.by_name
187 .insert(name.as_ref().to_string(), Arc::new(ck));
188 }
189 Ok(())
190 }
191}
192
193impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
194 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
195 if let Some(name) = client_hello.server_name() {
196 self.by_name.get(name).cloned()
197 } else {
198 None
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::ProducesTickets;
    use crate::server::ResolvesServerCert;
    use crate::server::StoresServerSessions;
    use std::vec;

    // Keys probed against stores that should never return anything.
    const PROBE_KEYS: [&[u8]; 3] = [&[], &[0x01], &[0x02]];

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let stored = storage.put(vec![0x01], vec![0x02]);
        assert!(!stored);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        storage.put(vec![0x01], vec![0x02]);
        for &key in &PROBE_KEYS {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        for &key in &PROBE_KEYS {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_serversessionmemorycache_accepts_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_serversessionmemorycache_persists_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        // Reading must not consume the entry, so repeat the lookup.
        for _ in 0..2 {
            assert_eq!(cache.get(&[0x01]), Some(vec![0x02]));
        }
    }

    #[test]
    fn test_serversessionmemorycache_overwrites_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        assert!(cache.put(vec![0x01], vec![0x04]));
        assert_eq!(cache.get(&[0x01]), Some(vec![0x04]));
    }

    #[test]
    fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
        let cache = ServerSessionMemoryCache::new(2);
        let keys: [u8; 5] = [0x01, 0x03, 0x05, 0x07, 0x09];
        for &key in &keys {
            assert!(cache.put(vec![key], vec![key + 1]));
        }

        let retained = keys
            .iter()
            .filter(|&&key| cache.get(&[key]).is_some())
            .count();

        assert!(retained < 5);
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let ticketer = NeverProducesTickets {};
        assert!(!ticketer.enabled());
        assert_eq!(0, ticketer.lifetime());
        assert_eq!(None, ticketer.encrypt(&[]));
        assert_eq!(None, ticketer.decrypt(&[]));
    }

    #[test]
    fn test_resolvesservercertusingsni_requires_sni() {
        let resolver = ResolvesServerCertUsingSni::new();
        let resolved = resolver.resolve(ClientHello::new(&None, &[], None, &[]));
        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_handles_unknown_name() {
        let resolver = ResolvesServerCertUsingSni::new();
        let name = DnsName::try_from("hello.com")
            .unwrap()
            .to_owned();
        let resolved = resolver.resolve(ClientHello::new(&Some(name), &[], None, &[]));
        assert!(resolved.is_none());
    }
}