rustls/msgs/
handshake.rs

1#![allow(non_camel_case_types)]
2
3#[cfg(feature = "tls12")]
4use crate::crypto::ActiveKeyExchange;
5use crate::crypto::SecureRandom;
6use crate::enums::{CipherSuite, HandshakeType, ProtocolVersion, SignatureScheme};
7use crate::error::InvalidMessage;
8#[cfg(feature = "logging")]
9use crate::log::warn;
10use crate::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
11use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
12use crate::msgs::enums::{
13    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
14    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
15    PSKKeyExchangeMode, ServerNameType,
16};
17use crate::rand;
18use crate::verify::DigitallySignedStruct;
19use crate::x509::wrap_in_sequence;
20
21use pki_types::{CertificateDer, DnsName};
22
23use alloc::collections::BTreeSet;
24#[cfg(feature = "logging")]
25use alloc::string::String;
26use alloc::vec;
27use alloc::vec::Vec;
28use core::fmt;
29use core::ops::Deref;
30
/// Declare a newtype around one of the length-prefixed payload types.
///
/// The generated type wraps a `PayloadU8` / `PayloadU16` style type and offers only
/// what TLS message fields of that kind need: construction from owned bytes
/// (`From<Vec<u8>>`), borrowed access to those bytes (`AsRef<[u8]>`), and a `Codec`
/// impl that delegates straight to the inner payload.
macro_rules! wrapped_payload {
    ($(#[$attr:meta])* $vis:vis struct $name:ident, $payload:ident,) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $vis struct $name($payload);

        impl From<Vec<u8>> for $name {
            fn from(bytes: Vec<u8>) -> Self {
                let inner = $payload::new(bytes);
                Self(inner)
            }
        }

        impl AsRef<[u8]> for $name {
            fn as_ref(&self) -> &[u8] {
                let Self(inner) = self;
                inner.0.as_slice()
            }
        }

        impl Codec for $name {
            fn encode(&self, bytes: &mut Vec<u8>) {
                let Self(inner) = self;
                inner.encode(bytes);
            }

            fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
                $payload::read(r).map($name)
            }
        }
    };
}
65
/// The 32-byte `random` value carried in hello messages (see `ClientHelloPayload::random`).
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Random(pub(crate) [u8; 32]);
68
69impl fmt::Debug for Random {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        super::base::hex(f, &self.0)
72    }
73}
74
/// The special `Random` value that marks a ServerHello as a HelloRetryRequest
/// (RFC 8446, section 4.1.3: the SHA-256 of the string "HelloRetryRequest").
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

/// A `Random` whose 32 bytes are all zero.
static ZERO_RANDOM: Random = Random([0u8; 32]);
81
82impl Codec for Random {
83    fn encode(&self, bytes: &mut Vec<u8>) {
84        bytes.extend_from_slice(&self.0);
85    }
86
87    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
88        let bytes = match r.take(32) {
89            Some(bytes) => bytes,
90            None => return Err(InvalidMessage::MissingData("Random")),
91        };
92
93        let mut opaque = [0; 32];
94        opaque.clone_from_slice(bytes);
95        Ok(Self(opaque))
96    }
97}
98
99impl Random {
100    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
101        let mut data = [0u8; 32];
102        secure_random.fill(&mut data)?;
103        Ok(Self(data))
104    }
105}
106
107impl From<[u8; 32]> for Random {
108    #[inline]
109    fn from(bytes: [u8; 32]) -> Self {
110        Self(bytes)
111    }
112}
113
/// A session identifier of at most 32 bytes.
///
/// Stored inline as a fixed 32-byte buffer plus the count of bytes in use;
/// only `data[..len]` takes part in formatting, comparison and encoding.
#[derive(Copy, Clone)]
pub struct SessionId {
    // Number of leading bytes of `data` that belong to the identifier.
    len: usize,
    // Backing storage; bytes past `len` are ignored.
    data: [u8; 32],
}
119
120impl fmt::Debug for SessionId {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        super::base::hex(f, &self.data[..self.len])
123    }
124}
125
126impl PartialEq for SessionId {
127    fn eq(&self, other: &Self) -> bool {
128        if self.len != other.len {
129            return false;
130        }
131
132        let mut diff = 0u8;
133        for i in 0..self.len {
134            diff |= self.data[i] ^ other.data[i];
135        }
136
137        diff == 0u8
138    }
139}
140
141impl Codec for SessionId {
142    fn encode(&self, bytes: &mut Vec<u8>) {
143        debug_assert!(self.len <= 32);
144        bytes.push(self.len as u8);
145        bytes.extend_from_slice(&self.data[..self.len]);
146    }
147
148    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
149        let len = u8::read(r)? as usize;
150        if len > 32 {
151            return Err(InvalidMessage::TrailingData("SessionID"));
152        }
153
154        let bytes = match r.take(len) {
155            Some(bytes) => bytes,
156            None => return Err(InvalidMessage::MissingData("SessionID")),
157        };
158
159        let mut out = [0u8; 32];
160        out[..len].clone_from_slice(&bytes[..len]);
161        Ok(Self { data: out, len })
162    }
163}
164
165impl SessionId {
166    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
167        let mut data = [0u8; 32];
168        secure_random.fill(&mut data)?;
169        Ok(Self { data, len: 32 })
170    }
171
172    pub(crate) fn empty() -> Self {
173        Self {
174            data: [0u8; 32],
175            len: 0,
176        }
177    }
178
179    #[cfg(feature = "tls12")]
180    pub(crate) fn is_empty(&self) -> bool {
181        self.len == 0
182    }
183}
184
/// An extension kept as an opaque payload together with its type code.
///
/// Used by the `Unknown` variants of `ClientExtension` and `ServerExtension`.
#[derive(Clone, Debug)]
pub struct UnknownExtension {
    pub(crate) typ: ExtensionType,
    pub(crate) payload: Payload,
}
190
191impl UnknownExtension {
192    fn encode(&self, bytes: &mut Vec<u8>) {
193        self.payload.encode(bytes);
194    }
195
196    fn read(typ: ExtensionType, r: &mut Reader) -> Self {
197        let payload = Payload::read(r);
198        Self { typ, payload }
199    }
200}
201
// Vectors of EC point formats use a one-byte length prefix.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Vectors of named groups use a two-byte length prefix.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Vectors of signature schemes use a two-byte length prefix.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::U16;
}
213
/// The body of one `ServerName` entry.
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload {
    /// A validated DNS host name (read for `ServerNameType::HostName`).
    HostName(DnsName<'static>),
    /// Uninterpreted bytes for any other name type.
    Unknown(Payload),
}
219
220impl ServerNamePayload {
221    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
222        Self::HostName(hostname)
223    }
224
225    fn read_hostname(r: &mut Reader) -> Result<Self, InvalidMessage> {
226        let raw = PayloadU16::read(r)?;
227
228        match DnsName::try_from(raw.0.as_slice()) {
229            Ok(dns_name) => Ok(Self::HostName(dns_name.to_owned())),
230            Err(_) => {
231                warn!(
232                    "Illegal SNI hostname received {:?}",
233                    String::from_utf8_lossy(&raw.0)
234                );
235                Err(InvalidMessage::InvalidServerName)
236            }
237        }
238    }
239
240    fn encode(&self, bytes: &mut Vec<u8>) {
241        match *self {
242            Self::HostName(ref name) => {
243                (name.as_ref().len() as u16).encode(bytes);
244                bytes.extend_from_slice(name.as_ref().as_bytes());
245            }
246            Self::Unknown(ref r) => r.encode(bytes),
247        }
248    }
249}
250
/// One entry of the SNI `server_name` list: a name type and its payload.
#[derive(Clone, Debug)]
pub struct ServerName {
    pub(crate) typ: ServerNameType,
    pub(crate) payload: ServerNamePayload,
}
256
257impl Codec for ServerName {
258    fn encode(&self, bytes: &mut Vec<u8>) {
259        self.typ.encode(bytes);
260        self.payload.encode(bytes);
261    }
262
263    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
264        let typ = ServerNameType::read(r)?;
265
266        let payload = match typ {
267            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
268            _ => ServerNamePayload::Unknown(Payload::read(r)),
269        };
270
271        Ok(Self { typ, payload })
272    }
273}
274
// Server name lists use a two-byte length prefix.
impl TlsListElement for ServerName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
278
/// Queries over a list of SNI `ServerName` entries.
pub(crate) trait ConvertServerNameList {
    /// Returns true if two entries share the same name type.
    fn has_duplicate_names_for_type(&self) -> bool;
    /// Returns a DNS host name from the list, if it contains one.
    fn get_single_hostname(&self) -> Option<DnsName<'_>>;
}
283
284impl ConvertServerNameList for [ServerName] {
285    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
286    fn has_duplicate_names_for_type(&self) -> bool {
287        let mut seen = BTreeSet::new();
288
289        for name in self {
290            if !seen.insert(name.typ.get_u8()) {
291                return true;
292            }
293        }
294
295        false
296    }
297
298    fn get_single_hostname(&self) -> Option<DnsName<'_>> {
299        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
300            if let ServerNamePayload::HostName(ref dns) = name.payload {
301                Some(dns.borrow())
302            } else {
303                None
304            }
305        }
306
307        self.iter()
308            .filter_map(only_dns_hostnames)
309            .next()
310    }
311}
312
// An ALPN protocol name: opaque bytes with a one-byte length prefix.
wrapped_payload!(pub struct ProtocolName, PayloadU8,);

// Lists of protocol names use a two-byte length prefix.
impl TlsListElement for ProtocolName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
318
/// Conversions between a list of `ProtocolName`s and plain byte slices.
pub(crate) trait ConvertProtocolNameList {
    /// Builds a list by copying each byte string, preserving order.
    fn from_slices(names: &[&[u8]]) -> Self;
    /// Borrows the bytes of every name, preserving order.
    fn to_slices(&self) -> Vec<&[u8]>;
    /// Returns the only name's bytes when the list has exactly one entry.
    fn as_single_slice(&self) -> Option<&[u8]>;
}
324
325impl ConvertProtocolNameList for Vec<ProtocolName> {
326    fn from_slices(names: &[&[u8]]) -> Self {
327        let mut ret = Self::new();
328
329        for name in names {
330            ret.push(ProtocolName::from(name.to_vec()));
331        }
332
333        ret
334    }
335
336    fn to_slices(&self) -> Vec<&[u8]> {
337        self.iter()
338            .map(|proto| proto.as_ref())
339            .collect::<Vec<&[u8]>>()
340    }
341
342    fn as_single_slice(&self) -> Option<&[u8]> {
343        if self.len() == 1 {
344            Some(self[0].as_ref())
345        } else {
346            None
347        }
348    }
349}
350
351// --- TLS 1.3 Key shares ---
/// One key share: a named group and its opaque key-exchange payload
/// (two-byte length prefix on the wire).
#[derive(Clone, Debug)]
pub struct KeyShareEntry {
    pub(crate) group: NamedGroup,
    pub(crate) payload: PayloadU16,
}
357
358impl KeyShareEntry {
359    pub fn new(group: NamedGroup, payload: &[u8]) -> Self {
360        Self {
361            group,
362            payload: PayloadU16::new(payload.to_vec()),
363        }
364    }
365
366    pub fn group(&self) -> NamedGroup {
367        self.group
368    }
369}
370
371impl Codec for KeyShareEntry {
372    fn encode(&self, bytes: &mut Vec<u8>) {
373        self.group.encode(bytes);
374        self.payload.encode(bytes);
375    }
376
377    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
378        let group = NamedGroup::read(r)?;
379        let payload = PayloadU16::read(r)?;
380
381        Ok(Self { group, payload })
382    }
383}
384
385// --- TLS 1.3 PresharedKey offers ---
/// One offered PSK identity: opaque identity bytes (two-byte length prefix)
/// and a `u32` obfuscated ticket age.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    pub(crate) identity: PayloadU16,
    pub(crate) obfuscated_ticket_age: u32,
}
391
392impl PresharedKeyIdentity {
393    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
394        Self {
395            identity: PayloadU16::new(id),
396            obfuscated_ticket_age: age,
397        }
398    }
399}
400
401impl Codec for PresharedKeyIdentity {
402    fn encode(&self, bytes: &mut Vec<u8>) {
403        self.identity.encode(bytes);
404        self.obfuscated_ticket_age.encode(bytes);
405    }
406
407    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
408        Ok(Self {
409            identity: PayloadU16::read(r)?,
410            obfuscated_ticket_age: u32::read(r)?,
411        })
412    }
413}
414
// Lists of PSK identities use a two-byte length prefix.
impl TlsListElement for PresharedKeyIdentity {
    const SIZE_LEN: ListLength = ListLength::U16;
}
418
// A PSK binder value: opaque bytes with a one-byte length prefix.
wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,);

// Lists of binders use a two-byte length prefix.
impl TlsListElement for PresharedKeyBinder {
    const SIZE_LEN: ListLength = ListLength::U16;
}
424
/// Body of the client's `pre_shared_key` extension: the offered identities,
/// followed on the wire by their binders.
#[derive(Clone, Debug)]
pub struct PresharedKeyOffer {
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
430
431impl PresharedKeyOffer {
432    /// Make a new one with one entry.
433    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
434        Self {
435            identities: vec![id],
436            binders: vec![PresharedKeyBinder::from(binder)],
437        }
438    }
439}
440
441impl Codec for PresharedKeyOffer {
442    fn encode(&self, bytes: &mut Vec<u8>) {
443        self.identities.encode(bytes);
444        self.binders.encode(bytes);
445    }
446
447    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
448        Ok(Self {
449            identities: Vec::read(r)?,
450            binders: Vec::read(r)?,
451        })
452    }
453}
454
455// --- RFC6066 certificate status request ---
// An OCSP responder identifier: opaque bytes with a two-byte length prefix.
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);

// Lists of responder IDs use a two-byte length prefix.
impl TlsListElement for ResponderId {
    const SIZE_LEN: ListLength = ListLength::U16;
}
461
/// The OCSP form of a certificate status request: responder IDs plus opaque
/// request extensions (two-byte length prefix).
#[derive(Clone, Debug)]
pub struct OcspCertificateStatusRequest {
    pub(crate) responder_ids: Vec<ResponderId>,
    pub(crate) extensions: PayloadU16,
}
467
468impl Codec for OcspCertificateStatusRequest {
469    fn encode(&self, bytes: &mut Vec<u8>) {
470        CertificateStatusType::OCSP.encode(bytes);
471        self.responder_ids.encode(bytes);
472        self.extensions.encode(bytes);
473    }
474
475    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
476        Ok(Self {
477            responder_ids: Vec::read(r)?,
478            extensions: PayloadU16::read(r)?,
479        })
480    }
481}
482
/// Body of the client's `status_request` extension.
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
    /// A parsed OCSP request.
    Ocsp(OcspCertificateStatusRequest),
    /// Any other status type, kept with its raw bytes.
    Unknown((CertificateStatusType, Payload)),
}
488
489impl Codec for CertificateStatusRequest {
490    fn encode(&self, bytes: &mut Vec<u8>) {
491        match self {
492            Self::Ocsp(ref r) => r.encode(bytes),
493            Self::Unknown((typ, payload)) => {
494                typ.encode(bytes);
495                payload.encode(bytes);
496            }
497        }
498    }
499
500    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
501        let typ = CertificateStatusType::read(r)?;
502
503        match typ {
504            CertificateStatusType::OCSP => {
505                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
506                Ok(Self::Ocsp(ocsp_req))
507            }
508            _ => {
509                let data = Payload::read(r);
510                Ok(Self::Unknown((typ, data)))
511            }
512        }
513    }
514}
515
516impl CertificateStatusRequest {
517    pub(crate) fn build_ocsp() -> Self {
518        let ocsp = OcspCertificateStatusRequest {
519            responder_ids: Vec::new(),
520            extensions: PayloadU16::empty(),
521        };
522        Self::Ocsp(ocsp)
523    }
524}
525
526// ---
527
// Lists of PSK key-exchange modes use a one-byte length prefix.
impl TlsListElement for PSKKeyExchangeMode {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of key shares use a two-byte length prefix.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of protocol versions (as in the client's supported_versions) use a
// one-byte length prefix.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::U8;
}
539
/// An extension as sent by a client (see `ClientHelloPayload::extensions`).
///
/// The wire type of each variant is given by `get_type`.
#[derive(Clone, Debug)]
pub enum ClientExtension {
    EcPointFormats(Vec<ECPointFormat>),
    // Carried under `ExtensionType::EllipticCurves`.
    NamedGroups(Vec<NamedGroup>),
    SignatureAlgorithms(Vec<SignatureScheme>),
    ServerName(Vec<ServerName>),
    SessionTicket(ClientSessionTicket),
    // ALPN protocol names (`ExtensionType::ALProtocolNegotiation`).
    Protocols(Vec<ProtocolName>),
    SupportedVersions(Vec<ProtocolVersion>),
    KeyShare(Vec<KeyShareEntry>),
    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
    PresharedKey(PresharedKeyOffer),
    Cookie(PayloadU16),
    // Empty-bodied `ExtendedMasterSecret` extension.
    ExtendedMasterSecretRequest,
    CertificateStatusRequest(CertificateStatusRequest),
    // Raw QUIC transport parameters, stored without inner framing.
    TransportParameters(Vec<u8>),
    TransportParametersDraft(Vec<u8>),
    // Empty-bodied `EarlyData` extension.
    EarlyData,
    // Any other type, or ExtendedMasterSecret / EarlyData with a non-empty body.
    Unknown(UnknownExtension),
}
560
561impl ClientExtension {
562    pub(crate) fn get_type(&self) -> ExtensionType {
563        match *self {
564            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
565            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
566            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
567            Self::ServerName(_) => ExtensionType::ServerName,
568            Self::SessionTicket(_) => ExtensionType::SessionTicket,
569            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
570            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
571            Self::KeyShare(_) => ExtensionType::KeyShare,
572            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
573            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
574            Self::Cookie(_) => ExtensionType::Cookie,
575            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
576            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
577            Self::TransportParameters(_) => ExtensionType::TransportParameters,
578            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
579            Self::EarlyData => ExtensionType::EarlyData,
580            Self::Unknown(ref r) => r.typ,
581        }
582    }
583}
584
/// Wire format for client extensions: the extension type (from `get_type`),
/// then a u16-length-prefixed body whose contents depend on the variant.
impl Codec for ClientExtension {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.get_type().encode(bytes);

        // The body is written into `nested.buf`. NOTE(review): the u16 length
        // prefix is presumably back-filled by `LengthPrefixedBuffer` when
        // `nested` is dropped — confirm in codec.rs.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match *self {
            Self::EcPointFormats(ref r) => r.encode(nested.buf),
            Self::NamedGroups(ref r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
            Self::ServerName(ref r) => r.encode(nested.buf),
            // These variants have an empty body.
            Self::SessionTicket(ClientSessionTicket::Request)
            | Self::ExtendedMasterSecretRequest
            | Self::EarlyData => {}
            Self::SessionTicket(ClientSessionTicket::Offer(ref r)) => r.encode(nested.buf),
            Self::Protocols(ref r) => r.encode(nested.buf),
            Self::SupportedVersions(ref r) => r.encode(nested.buf),
            Self::KeyShare(ref r) => r.encode(nested.buf),
            Self::PresharedKeyModes(ref r) => r.encode(nested.buf),
            Self::PresharedKey(ref r) => r.encode(nested.buf),
            Self::Cookie(ref r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(ref r) => r.encode(nested.buf),
            // Transport parameters are copied verbatim, with no inner framing.
            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::Unknown(ref r) => r.encode(nested.buf),
        }
    }

    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        // All body parsing happens on a sub-reader bounded by the declared length.
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
            ExtensionType::SessionTicket => {
                // Non-empty body: an offered ticket. Empty body: `Request`.
                if sub.any_left() {
                    let contents = Payload::read(&mut sub);
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
                } else {
                    Self::SessionTicket(ClientSessionTicket::Request)
                }
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            // Only recognised with an empty body; otherwise falls through to `Unknown`.
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            // Same empty-body rule as ExtendedMasterSecret above.
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        // Reject any bytes the variant-specific parser left unconsumed.
        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
656
657fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
658    let dns_name_str = dns_name.as_ref();
659
660    // RFC6066: "The hostname is represented as a byte string using
661    // ASCII encoding without a trailing dot"
662    if dns_name_str.ends_with('.') {
663        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
664        DnsName::try_from(trimmed)
665            .unwrap()
666            .to_owned()
667    } else {
668        dns_name.to_owned()
669    }
670}
671
672impl ClientExtension {
673    /// Make a basic SNI ServerNameRequest quoting `hostname`.
674    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
675        let name = ServerName {
676            typ: ServerNameType::HostName,
677            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
678        };
679
680        Self::ServerName(vec![name])
681    }
682}
683
/// Contents of the client's `SessionTicket` extension.
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
    /// Empty extension body.
    Request,
    /// Non-empty extension body, kept as raw ticket bytes.
    Offer(Payload),
}
689
/// An extension as sent by a server. The wire type of each variant is given
/// by `get_type`.
#[derive(Clone, Debug)]
pub enum ServerExtension {
    EcPointFormats(Vec<ECPointFormat>),
    // The `*Ack` variants and `EarlyData` are written with an empty body.
    ServerNameAck,
    SessionTicketAck,
    RenegotiationInfo(PayloadU8),
    Protocols(Vec<ProtocolName>),
    KeyShare(KeyShareEntry),
    // A bare u16; per RFC 8446 this is the server's `selected_identity` index.
    PresharedKey(u16),
    ExtendedMasterSecretAck,
    CertificateStatusAck,
    SupportedVersions(ProtocolVersion),
    // Raw QUIC transport parameters, stored without inner framing.
    TransportParameters(Vec<u8>),
    TransportParametersDraft(Vec<u8>),
    EarlyData,
    Unknown(UnknownExtension),
}
707
708impl ServerExtension {
709    pub(crate) fn get_type(&self) -> ExtensionType {
710        match *self {
711            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
712            Self::ServerNameAck => ExtensionType::ServerName,
713            Self::SessionTicketAck => ExtensionType::SessionTicket,
714            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
715            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
716            Self::KeyShare(_) => ExtensionType::KeyShare,
717            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
718            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
719            Self::CertificateStatusAck => ExtensionType::StatusRequest,
720            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
721            Self::TransportParameters(_) => ExtensionType::TransportParameters,
722            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
723            Self::EarlyData => ExtensionType::EarlyData,
724            Self::Unknown(ref r) => r.typ,
725        }
726    }
727}
728
/// Wire format for server extensions: the extension type (from `get_type`),
/// then a u16-length-prefixed body whose contents depend on the variant.
impl Codec for ServerExtension {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.get_type().encode(bytes);

        // The body is written into `nested.buf`. NOTE(review): the u16 length
        // prefix is presumably back-filled by `LengthPrefixedBuffer` when
        // `nested` is dropped — confirm in codec.rs.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match *self {
            Self::EcPointFormats(ref r) => r.encode(nested.buf),
            // These variants have an empty body.
            Self::ServerNameAck
            | Self::SessionTicketAck
            | Self::ExtendedMasterSecretAck
            | Self::CertificateStatusAck
            | Self::EarlyData => {}
            Self::RenegotiationInfo(ref r) => r.encode(nested.buf),
            Self::Protocols(ref r) => r.encode(nested.buf),
            Self::KeyShare(ref r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::SupportedVersions(ref r) => r.encode(nested.buf),
            // Transport parameters are copied verbatim, with no inner framing.
            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::Unknown(ref r) => r.encode(nested.buf),
        }
    }

    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        // All body parsing happens on a sub-reader bounded by the declared length.
        let mut sub = r.sub(len)?;

        // The acknowledgement arms below consume nothing, so a non-empty body for
        // them is rejected by `expect_empty` at the end.
        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerNameAck,
            ExtensionType::SessionTicket => Self::SessionTicketAck,
            ExtensionType::StatusRequest => Self::CertificateStatusAck,
            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            ExtensionType::EarlyData => Self::EarlyData,
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        // Reject any bytes the variant-specific parser left unconsumed.
        sub.expect_empty("ServerExtension")
            .map(|_| ext)
    }
}
783
784impl ServerExtension {
785    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
786        Self::Protocols(Vec::from_slices(proto))
787    }
788
789    #[cfg(feature = "tls12")]
790    pub(crate) fn make_empty_renegotiation_info() -> Self {
791        let empty = Vec::new();
792        Self::RenegotiationInfo(PayloadU8::new(empty))
793    }
794}
795
/// The body of a ClientHello message, in wire order.
#[derive(Debug)]
pub struct ClientHelloPayload {
    pub client_version: ProtocolVersion,
    pub random: Random,
    pub session_id: SessionId,
    pub cipher_suites: Vec<CipherSuite>,
    pub compression_methods: Vec<Compression>,
    // Omitted on the wire when empty; see the `Codec` impl.
    pub extensions: Vec<ClientExtension>,
}
805
806impl Codec for ClientHelloPayload {
807    fn encode(&self, bytes: &mut Vec<u8>) {
808        self.client_version.encode(bytes);
809        self.random.encode(bytes);
810        self.session_id.encode(bytes);
811        self.cipher_suites.encode(bytes);
812        self.compression_methods.encode(bytes);
813
814        if !self.extensions.is_empty() {
815            self.extensions.encode(bytes);
816        }
817    }
818
819    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
820        let mut ret = Self {
821            client_version: ProtocolVersion::read(r)?,
822            random: Random::read(r)?,
823            session_id: SessionId::read(r)?,
824            cipher_suites: Vec::read(r)?,
825            compression_methods: Vec::read(r)?,
826            extensions: Vec::new(),
827        };
828
829        if r.any_left() {
830            ret.extensions = Vec::read(r)?;
831        }
832
833        match (r.any_left(), ret.extensions.is_empty()) {
834            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
835            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
836            _ => Ok(ret),
837        }
838    }
839}
840
// Cipher suite lists use a two-byte length prefix.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Compression method lists use a one-byte length prefix.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Client extension lists use a two-byte length prefix.
impl TlsListElement for ClientExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
852
853impl ClientHelloPayload {
854    /// Returns true if there is more than one extension of a given
855    /// type.
856    pub(crate) fn has_duplicate_extension(&self) -> bool {
857        let mut seen = BTreeSet::new();
858
859        for ext in &self.extensions {
860            let typ = ext.get_type().get_u16();
861
862            if seen.contains(&typ) {
863                return true;
864            }
865            seen.insert(typ);
866        }
867
868        false
869    }
870
871    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
872        self.extensions
873            .iter()
874            .find(|x| x.get_type() == ext)
875    }
876
877    pub(crate) fn get_sni_extension(&self) -> Option<&[ServerName]> {
878        let ext = self.find_extension(ExtensionType::ServerName)?;
879        match *ext {
880            ClientExtension::ServerName(ref req) => Some(req),
881            _ => None,
882        }
883    }
884
885    pub fn get_sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
886        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
887        match *ext {
888            ClientExtension::SignatureAlgorithms(ref req) => Some(req),
889            _ => None,
890        }
891    }
892
893    pub(crate) fn get_namedgroups_extension(&self) -> Option<&[NamedGroup]> {
894        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
895        match *ext {
896            ClientExtension::NamedGroups(ref req) => Some(req),
897            _ => None,
898        }
899    }
900
901    #[cfg(feature = "tls12")]
902    pub(crate) fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
903        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
904        match *ext {
905            ClientExtension::EcPointFormats(ref req) => Some(req),
906            _ => None,
907        }
908    }
909
910    pub(crate) fn get_alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
911        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
912        match *ext {
913            ClientExtension::Protocols(ref req) => Some(req),
914            _ => None,
915        }
916    }
917
918    pub(crate) fn get_quic_params_extension(&self) -> Option<Vec<u8>> {
919        let ext = self
920            .find_extension(ExtensionType::TransportParameters)
921            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
922        match *ext {
923            ClientExtension::TransportParameters(ref bytes)
924            | ClientExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
925            _ => None,
926        }
927    }
928
929    #[cfg(feature = "tls12")]
930    pub(crate) fn get_ticket_extension(&self) -> Option<&ClientExtension> {
931        self.find_extension(ExtensionType::SessionTicket)
932    }
933
934    pub(crate) fn get_versions_extension(&self) -> Option<&[ProtocolVersion]> {
935        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
936        match *ext {
937            ClientExtension::SupportedVersions(ref vers) => Some(vers),
938            _ => None,
939        }
940    }
941
942    pub fn get_keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
943        let ext = self.find_extension(ExtensionType::KeyShare)?;
944        match *ext {
945            ClientExtension::KeyShare(ref shares) => Some(shares),
946            _ => None,
947        }
948    }
949
950    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
951        if let Some(entries) = self.get_keyshare_extension() {
952            let mut seen = BTreeSet::new();
953
954            for kse in entries {
955                let grp = kse.group.get_u16();
956
957                if !seen.insert(grp) {
958                    return true;
959                }
960            }
961        }
962
963        false
964    }
965
966    pub(crate) fn get_psk(&self) -> Option<&PresharedKeyOffer> {
967        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
968        match *ext {
969            ClientExtension::PresharedKey(ref psk) => Some(psk),
970            _ => None,
971        }
972    }
973
974    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
975        self.extensions
976            .last()
977            .map_or(false, |ext| ext.get_type() == ExtensionType::PreSharedKey)
978    }
979
980    pub(crate) fn get_psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
981        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
982        match *ext {
983            ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
984            _ => None,
985        }
986    }
987
988    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
989        self.get_psk_modes()
990            .map(|modes| modes.contains(&mode))
991            .unwrap_or(false)
992    }
993
994    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
995        let last_extension = self.extensions.last_mut();
996        if let Some(ClientExtension::PresharedKey(ref mut offer)) = last_extension {
997            offer.binders[0] = PresharedKeyBinder::from(binder.into());
998        }
999    }
1000
    #[cfg(feature = "tls12")]
    /// Returns `true` if the client sent an `ExtendedMasterSecret` extension.
    pub(crate) fn ems_support_offered(&self) -> bool {
        self.find_extension(ExtensionType::ExtendedMasterSecret)
            .is_some()
    }
1006
    /// Returns `true` if the client sent an `EarlyData` extension.
    pub(crate) fn early_data_extension_offered(&self) -> bool {
        self.find_extension(ExtensionType::EarlyData)
            .is_some()
    }
1011}
1012
/// An extension that may appear in a HelloRetryRequest.
///
/// Extension types without a dedicated variant are decoded as `Unknown`.
#[derive(Debug)]
pub(crate) enum HelloRetryExtension {
    /// The single named group carried in a `key_share` extension.
    KeyShare(NamedGroup),
    /// The payload of a `cookie` extension.
    Cookie(PayloadU16),
    /// The single protocol version carried in a `supported_versions` extension.
    SupportedVersions(ProtocolVersion),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1020
1021impl HelloRetryExtension {
1022    pub(crate) fn get_type(&self) -> ExtensionType {
1023        match *self {
1024            Self::KeyShare(_) => ExtensionType::KeyShare,
1025            Self::Cookie(_) => ExtensionType::Cookie,
1026            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1027            Self::Unknown(ref r) => r.typ,
1028        }
1029    }
1030}
1031
1032impl Codec for HelloRetryExtension {
1033    fn encode(&self, bytes: &mut Vec<u8>) {
1034        self.get_type().encode(bytes);
1035
1036        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1037        match *self {
1038            Self::KeyShare(ref r) => r.encode(nested.buf),
1039            Self::Cookie(ref r) => r.encode(nested.buf),
1040            Self::SupportedVersions(ref r) => r.encode(nested.buf),
1041            Self::Unknown(ref r) => r.encode(nested.buf),
1042        }
1043    }
1044
1045    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1046        let typ = ExtensionType::read(r)?;
1047        let len = u16::read(r)? as usize;
1048        let mut sub = r.sub(len)?;
1049
1050        let ext = match typ {
1051            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1052            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1053            ExtensionType::SupportedVersions => {
1054                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1055            }
1056            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1057        };
1058
1059        sub.expect_empty("HelloRetryExtension")
1060            .map(|_| ext)
1061    }
1062}
1063
/// Lists of HelloRetryRequest extensions use a u16 length prefix.
impl TlsListElement for HelloRetryExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1067
/// A HelloRetryRequest message.
///
/// `Codec::read` does not decode `legacy_version`; it stores `ProtocolVersion::Unknown(0)`
/// as a placeholder.
#[derive(Debug)]
pub struct HelloRetryRequest {
    pub(crate) legacy_version: ProtocolVersion,
    pub session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) extensions: Vec<HelloRetryExtension>,
}
1075
1076impl Codec for HelloRetryRequest {
1077    fn encode(&self, bytes: &mut Vec<u8>) {
1078        self.legacy_version.encode(bytes);
1079        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1080        self.session_id.encode(bytes);
1081        self.cipher_suite.encode(bytes);
1082        Compression::Null.encode(bytes);
1083        self.extensions.encode(bytes);
1084    }
1085
1086    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1087        let session_id = SessionId::read(r)?;
1088        let cipher_suite = CipherSuite::read(r)?;
1089        let compression = Compression::read(r)?;
1090
1091        if compression != Compression::Null {
1092            return Err(InvalidMessage::UnsupportedCompression);
1093        }
1094
1095        Ok(Self {
1096            legacy_version: ProtocolVersion::Unknown(0),
1097            session_id,
1098            cipher_suite,
1099            extensions: Vec::read(r)?,
1100        })
1101    }
1102}
1103
1104impl HelloRetryRequest {
1105    /// Returns true if there is more than one extension of a given
1106    /// type.
1107    pub(crate) fn has_duplicate_extension(&self) -> bool {
1108        let mut seen = BTreeSet::new();
1109
1110        for ext in &self.extensions {
1111            let typ = ext.get_type().get_u16();
1112
1113            if seen.contains(&typ) {
1114                return true;
1115            }
1116            seen.insert(typ);
1117        }
1118
1119        false
1120    }
1121
1122    pub(crate) fn has_unknown_extension(&self) -> bool {
1123        self.extensions.iter().any(|ext| {
1124            ext.get_type() != ExtensionType::KeyShare
1125                && ext.get_type() != ExtensionType::SupportedVersions
1126                && ext.get_type() != ExtensionType::Cookie
1127        })
1128    }
1129
1130    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1131        self.extensions
1132            .iter()
1133            .find(|x| x.get_type() == ext)
1134    }
1135
1136    pub fn get_requested_key_share_group(&self) -> Option<NamedGroup> {
1137        let ext = self.find_extension(ExtensionType::KeyShare)?;
1138        match *ext {
1139            HelloRetryExtension::KeyShare(grp) => Some(grp),
1140            _ => None,
1141        }
1142    }
1143
1144    pub(crate) fn get_cookie(&self) -> Option<&PayloadU16> {
1145        let ext = self.find_extension(ExtensionType::Cookie)?;
1146        match *ext {
1147            HelloRetryExtension::Cookie(ref ck) => Some(ck),
1148            _ => None,
1149        }
1150    }
1151
1152    pub(crate) fn get_supported_versions(&self) -> Option<ProtocolVersion> {
1153        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1154        match *ext {
1155            HelloRetryExtension::SupportedVersions(ver) => Some(ver),
1156            _ => None,
1157        }
1158    }
1159}
1160
/// The body of a ServerHello message.
///
/// When produced by `Codec::read`, `legacy_version` and `random` are placeholders
/// (`ProtocolVersion::Unknown(0)` and `ZERO_RANDOM`), because that function does not read them.
#[derive(Debug)]
pub struct ServerHelloPayload {
    pub(crate) legacy_version: ProtocolVersion,
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) compression_method: Compression,
    pub(crate) extensions: Vec<ServerExtension>,
}
1170
1171impl Codec for ServerHelloPayload {
1172    fn encode(&self, bytes: &mut Vec<u8>) {
1173        self.legacy_version.encode(bytes);
1174        self.random.encode(bytes);
1175
1176        self.session_id.encode(bytes);
1177        self.cipher_suite.encode(bytes);
1178        self.compression_method.encode(bytes);
1179
1180        if !self.extensions.is_empty() {
1181            self.extensions.encode(bytes);
1182        }
1183    }
1184
1185    // minus version and random, which have already been read.
1186    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1187        let session_id = SessionId::read(r)?;
1188        let suite = CipherSuite::read(r)?;
1189        let compression = Compression::read(r)?;
1190
1191        // RFC5246:
1192        // "The presence of extensions can be detected by determining whether
1193        //  there are bytes following the compression_method field at the end of
1194        //  the ServerHello."
1195        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1196
1197        let ret = Self {
1198            legacy_version: ProtocolVersion::Unknown(0),
1199            random: ZERO_RANDOM,
1200            session_id,
1201            cipher_suite: suite,
1202            compression_method: compression,
1203            extensions,
1204        };
1205
1206        r.expect_empty("ServerHelloPayload")
1207            .map(|_| ret)
1208    }
1209}
1210
/// Gives the ServerHello access to the shared `HasServerExtensions` lookup helpers.
impl HasServerExtensions for ServerHelloPayload {
    fn get_extensions(&self) -> &[ServerExtension] {
        &self.extensions
    }
}
1216
1217impl ServerHelloPayload {
1218    pub(crate) fn get_key_share(&self) -> Option<&KeyShareEntry> {
1219        let ext = self.find_extension(ExtensionType::KeyShare)?;
1220        match *ext {
1221            ServerExtension::KeyShare(ref share) => Some(share),
1222            _ => None,
1223        }
1224    }
1225
1226    pub(crate) fn get_psk_index(&self) -> Option<u16> {
1227        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1228        match *ext {
1229            ServerExtension::PresharedKey(ref index) => Some(*index),
1230            _ => None,
1231        }
1232    }
1233
1234    pub(crate) fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1235        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1236        match *ext {
1237            ServerExtension::EcPointFormats(ref fmts) => Some(fmts),
1238            _ => None,
1239        }
1240    }
1241
1242    #[cfg(feature = "tls12")]
1243    pub(crate) fn ems_support_acked(&self) -> bool {
1244        self.find_extension(ExtensionType::ExtendedMasterSecret)
1245            .is_some()
1246    }
1247
1248    pub(crate) fn get_supported_versions(&self) -> Option<ProtocolVersion> {
1249        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1250        match *ext {
1251            ServerExtension::SupportedVersions(vers) => Some(vers),
1252            _ => None,
1253        }
1254    }
1255}
1256
/// A list of certificates, each in DER form.
///
/// On the wire this is a length-prefixed vector of certificates.
#[derive(Clone, Default, Debug)]
pub struct CertificateChain(pub Vec<CertificateDer<'static>>);
1259
1260impl Codec for CertificateChain {
1261    fn encode(&self, bytes: &mut Vec<u8>) {
1262        Vec::encode(&self.0, bytes)
1263    }
1264
1265    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1266        Vec::read(r).map(Self)
1267    }
1268}
1269
1270impl Deref for CertificateChain {
1271    type Target = [CertificateDer<'static>];
1272
1273    fn deref(&self) -> &[CertificateDer<'static>] {
1274        &self.0
1275    }
1276}
1277
/// Certificate lists use a 24-bit length prefix.
// NOTE(review): `max: 0x1_0000` presumably caps the accepted list length — confirm in `ListLength`.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
}
1281
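// TLS 1.3 changes the Certificate payload encoding: each certificate entry also
// carries its own list of extensions. As a result, parsing a Certificate message
// is no longer context-free; the decoder must know which protocol version is in use.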
1285
/// An extension attached to a single TLS 1.3 certificate entry.
#[derive(Debug)]
pub(crate) enum CertificateExtension {
    /// A `StatusRequest` extension, which carries an OCSP response.
    CertificateStatus(CertificateStatus),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1291
1292impl CertificateExtension {
1293    pub(crate) fn get_type(&self) -> ExtensionType {
1294        match *self {
1295            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1296            Self::Unknown(ref r) => r.typ,
1297        }
1298    }
1299
1300    pub(crate) fn get_cert_status(&self) -> Option<&Vec<u8>> {
1301        match *self {
1302            Self::CertificateStatus(ref cs) => Some(&cs.ocsp_response.0),
1303            _ => None,
1304        }
1305    }
1306}
1307
1308impl Codec for CertificateExtension {
1309    fn encode(&self, bytes: &mut Vec<u8>) {
1310        self.get_type().encode(bytes);
1311
1312        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1313        match *self {
1314            Self::CertificateStatus(ref r) => r.encode(nested.buf),
1315            Self::Unknown(ref r) => r.encode(nested.buf),
1316        }
1317    }
1318
1319    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1320        let typ = ExtensionType::read(r)?;
1321        let len = u16::read(r)? as usize;
1322        let mut sub = r.sub(len)?;
1323
1324        let ext = match typ {
1325            ExtensionType::StatusRequest => {
1326                let st = CertificateStatus::read(&mut sub)?;
1327                Self::CertificateStatus(st)
1328            }
1329            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1330        };
1331
1332        sub.expect_empty("CertificateExtension")
1333            .map(|_| ext)
1334    }
1335}
1336
/// Per-certificate extension lists use a u16 length prefix.
impl TlsListElement for CertificateExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1340
/// One entry of a TLS 1.3 Certificate message: a certificate plus its own extensions.
#[derive(Debug)]
pub(crate) struct CertificateEntry {
    pub(crate) cert: CertificateDer<'static>,
    pub(crate) exts: Vec<CertificateExtension>,
}
1346
1347impl Codec for CertificateEntry {
1348    fn encode(&self, bytes: &mut Vec<u8>) {
1349        self.cert.encode(bytes);
1350        self.exts.encode(bytes);
1351    }
1352
1353    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1354        Ok(Self {
1355            cert: CertificateDer::read(r)?,
1356            exts: Vec::read(r)?,
1357        })
1358    }
1359}
1360
1361impl CertificateEntry {
1362    pub(crate) fn new(cert: CertificateDer<'static>) -> Self {
1363        Self {
1364            cert,
1365            exts: Vec::new(),
1366        }
1367    }
1368
1369    pub(crate) fn has_duplicate_extension(&self) -> bool {
1370        let mut seen = BTreeSet::new();
1371
1372        for ext in &self.exts {
1373            let typ = ext.get_type().get_u16();
1374
1375            if seen.contains(&typ) {
1376                return true;
1377            }
1378            seen.insert(typ);
1379        }
1380
1381        false
1382    }
1383
1384    pub(crate) fn has_unknown_extension(&self) -> bool {
1385        self.exts
1386            .iter()
1387            .any(|ext| ext.get_type() != ExtensionType::StatusRequest)
1388    }
1389
1390    pub(crate) fn get_ocsp_response(&self) -> Option<&Vec<u8>> {
1391        self.exts
1392            .iter()
1393            .find(|ext| ext.get_type() == ExtensionType::StatusRequest)
1394            .and_then(CertificateExtension::get_cert_status)
1395    }
1396}
1397
/// Certificate entry lists use a 24-bit length prefix.
// NOTE(review): `max: 0x1_0000` presumably caps the accepted list length — confirm in `ListLength`.
impl TlsListElement for CertificateEntry {
    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
}
1401
/// A TLS 1.3 Certificate message: a request context followed by certificate entries.
#[derive(Debug)]
pub struct CertificatePayloadTls13 {
    pub(crate) context: PayloadU8,
    pub(crate) entries: Vec<CertificateEntry>,
}
1407
1408impl Codec for CertificatePayloadTls13 {
1409    fn encode(&self, bytes: &mut Vec<u8>) {
1410        self.context.encode(bytes);
1411        self.entries.encode(bytes);
1412    }
1413
1414    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1415        Ok(Self {
1416            context: PayloadU8::read(r)?,
1417            entries: Vec::read(r)?,
1418        })
1419    }
1420}
1421
1422impl CertificatePayloadTls13 {
1423    pub(crate) fn new(entries: Vec<CertificateEntry>) -> Self {
1424        Self {
1425            context: PayloadU8::empty(),
1426            entries,
1427        }
1428    }
1429
1430    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1431        for entry in &self.entries {
1432            if entry.has_duplicate_extension() {
1433                return true;
1434            }
1435        }
1436
1437        false
1438    }
1439
1440    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1441        for entry in &self.entries {
1442            if entry.has_unknown_extension() {
1443                return true;
1444            }
1445        }
1446
1447        false
1448    }
1449
1450    pub(crate) fn any_entry_has_extension(&self) -> bool {
1451        for entry in &self.entries {
1452            if !entry.exts.is_empty() {
1453                return true;
1454            }
1455        }
1456
1457        false
1458    }
1459
1460    pub(crate) fn get_end_entity_ocsp(&self) -> Vec<u8> {
1461        self.entries
1462            .first()
1463            .and_then(CertificateEntry::get_ocsp_response)
1464            .cloned()
1465            .unwrap_or_default()
1466    }
1467
1468    pub(crate) fn convert(self) -> CertificateChain {
1469        CertificateChain(
1470            self.entries
1471                .into_iter()
1472                .map(|e| e.cert)
1473                .collect(),
1474        )
1475    }
1476}
1477
/// Describes supported key exchange mechanisms.
///
/// This is `#[non_exhaustive]`, so matches on it outside this crate need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}
1485
// We deliberately don't support arbitrary curves: they add unnecessary attack
// surface for no practical benefit. Decoding rejects any curve type other than
// `ECCurveType::NamedCurve`.
/// Elliptic-curve parameters: the curve type and the named group.
#[derive(Debug)]
pub(crate) struct EcParameters {
    pub(crate) curve_type: ECCurveType,
    pub(crate) named_group: NamedGroup,
}
1494
1495impl Codec for EcParameters {
1496    fn encode(&self, bytes: &mut Vec<u8>) {
1497        self.curve_type.encode(bytes);
1498        self.named_group.encode(bytes);
1499    }
1500
1501    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1502        let ct = ECCurveType::read(r)?;
1503        if ct != ECCurveType::NamedCurve {
1504            return Err(InvalidMessage::UnsupportedCurveType);
1505        }
1506
1507        let grp = NamedGroup::read(r)?;
1508
1509        Ok(Self {
1510            curve_type: ct,
1511            named_group: grp,
1512        })
1513    }
1514}
1515
/// The client's ECDHE public value, held in a `PayloadU8`.
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    pub(crate) public: PayloadU8,
}
1520
1521impl Codec for ClientEcdhParams {
1522    fn encode(&self, bytes: &mut Vec<u8>) {
1523        self.public.encode(bytes);
1524    }
1525
1526    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1527        let pb = PayloadU8::read(r)?;
1528        Ok(Self { public: pb })
1529    }
1530}
1531
/// The server's ECDHE parameters: the curve description and its public value.
#[derive(Debug)]
pub(crate) struct ServerEcdhParams {
    pub(crate) curve_params: EcParameters,
    pub(crate) public: PayloadU8,
}
1537
1538impl ServerEcdhParams {
1539    #[cfg(feature = "tls12")]
1540    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1541        Self {
1542            curve_params: EcParameters {
1543                curve_type: ECCurveType::NamedCurve,
1544                named_group: kx.group(),
1545            },
1546            public: PayloadU8::new(kx.pub_key().to_vec()),
1547        }
1548    }
1549}
1550
1551impl Codec for ServerEcdhParams {
1552    fn encode(&self, bytes: &mut Vec<u8>) {
1553        self.curve_params.encode(bytes);
1554        self.public.encode(bytes);
1555    }
1556
1557    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1558        let cp = EcParameters::read(r)?;
1559        let pb = PayloadU8::read(r)?;
1560
1561        Ok(Self {
1562            curve_params: cp,
1563            public: pb,
1564        })
1565    }
1566}
1567
/// An ECDHE ServerKeyExchange body: the ECDH parameters followed by a `DigitallySignedStruct`.
#[derive(Debug)]
pub struct EcdheServerKeyExchange {
    pub(crate) params: ServerEcdhParams,
    pub(crate) dss: DigitallySignedStruct,
}
1573
1574impl Codec for EcdheServerKeyExchange {
1575    fn encode(&self, bytes: &mut Vec<u8>) {
1576        self.params.encode(bytes);
1577        self.dss.encode(bytes);
1578    }
1579
1580    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1581        let params = ServerEcdhParams::read(r)?;
1582        let dss = DigitallySignedStruct::read(r)?;
1583
1584        Ok(Self { params, dss })
1585    }
1586}
1587
/// A ServerKeyExchange body.
///
/// `Codec::read` always yields `Unknown`. The typed form is obtained later with
/// `unwrap_given_kxa`, once the key exchange algorithm is known.
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
    Ecdhe(EcdheServerKeyExchange),
    Unknown(Payload),
}
1593
1594impl Codec for ServerKeyExchangePayload {
1595    fn encode(&self, bytes: &mut Vec<u8>) {
1596        match *self {
1597            Self::Ecdhe(ref x) => x.encode(bytes),
1598            Self::Unknown(ref x) => x.encode(bytes),
1599        }
1600    }
1601
1602    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1603        // read as Unknown, fully parse when we know the
1604        // KeyExchangeAlgorithm
1605        Ok(Self::Unknown(Payload::read(r)))
1606    }
1607}
1608
1609impl ServerKeyExchangePayload {
1610    #[cfg(feature = "tls12")]
1611    pub(crate) fn unwrap_given_kxa(
1612        &self,
1613        kxa: KeyExchangeAlgorithm,
1614    ) -> Option<EcdheServerKeyExchange> {
1615        if let Self::Unknown(ref unk) = *self {
1616            let mut rd = Reader::init(&unk.0);
1617
1618            let result = match kxa {
1619                KeyExchangeAlgorithm::ECDHE => EcdheServerKeyExchange::read(&mut rd),
1620            };
1621
1622            if !rd.any_left() {
1623                return result.ok();
1624            };
1625        }
1626
1627        None
1628    }
1629}
1630
// -- EncryptedExtensions (TLS1.3 only) --

/// Server extension lists use a u16 length prefix.
impl TlsListElement for ServerExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1636
1637pub(crate) trait HasServerExtensions {
1638    fn get_extensions(&self) -> &[ServerExtension];
1639
1640    /// Returns true if there is more than one extension of a given
1641    /// type.
1642    fn has_duplicate_extension(&self) -> bool {
1643        let mut seen = BTreeSet::new();
1644
1645        for ext in self.get_extensions() {
1646            let typ = ext.get_type().get_u16();
1647
1648            if seen.contains(&typ) {
1649                return true;
1650            }
1651            seen.insert(typ);
1652        }
1653
1654        false
1655    }
1656
1657    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
1658        self.get_extensions()
1659            .iter()
1660            .find(|x| x.get_type() == ext)
1661    }
1662
1663    fn get_alpn_protocol(&self) -> Option<&[u8]> {
1664        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1665        match *ext {
1666            ServerExtension::Protocols(ref protos) => protos.as_single_slice(),
1667            _ => None,
1668        }
1669    }
1670
1671    fn get_quic_params_extension(&self) -> Option<Vec<u8>> {
1672        let ext = self
1673            .find_extension(ExtensionType::TransportParameters)
1674            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1675        match *ext {
1676            ServerExtension::TransportParameters(ref bytes)
1677            | ServerExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
1678            _ => None,
1679        }
1680    }
1681
1682    fn early_data_extension_offered(&self) -> bool {
1683        self.find_extension(ExtensionType::EarlyData)
1684            .is_some()
1685    }
1686}
1687
/// Lets a bare `Vec<ServerExtension>` use the `HasServerExtensions` helpers directly.
impl HasServerExtensions for Vec<ServerExtension> {
    fn get_extensions(&self) -> &[ServerExtension] {
        self
    }
}
1693
/// Certificate type lists use a u8 length prefix.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
1697
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    /// The raw bytes are available through `AsRef<[u8]>`.
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    pub struct DistinguishedName,
    PayloadU16,
);
1714
1715impl DistinguishedName {
1716    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
1717    ///
1718    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
1719    ///
1720    /// ```ignore
1721    /// use x509_parser::prelude::FromDer;
1722    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
1723    /// ```
1724    pub fn in_sequence(bytes: &[u8]) -> Self {
1725        Self(PayloadU16::new(wrap_in_sequence(bytes)))
1726    }
1727}
1728
/// Distinguished-name lists use a u16 length prefix.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1732
/// A CertificateRequest body that carries explicit lists of certificate types,
/// signature schemes and CA names.
///
/// Contrast `CertificateRequestPayloadTls13`, which carries these as extensions.
#[derive(Debug)]
pub struct CertificateRequestPayload {
    pub(crate) certtypes: Vec<ClientCertificateType>,
    pub(crate) sigschemes: Vec<SignatureScheme>,
    pub(crate) canames: Vec<DistinguishedName>,
}
1739
1740impl Codec for CertificateRequestPayload {
1741    fn encode(&self, bytes: &mut Vec<u8>) {
1742        self.certtypes.encode(bytes);
1743        self.sigschemes.encode(bytes);
1744        self.canames.encode(bytes);
1745    }
1746
1747    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1748        let certtypes = Vec::read(r)?;
1749        let sigschemes = Vec::read(r)?;
1750        let canames = Vec::read(r)?;
1751
1752        if sigschemes.is_empty() {
1753            warn!("meaningless CertificateRequest message");
1754            Err(InvalidMessage::NoSignatureSchemes)
1755        } else {
1756            Ok(Self {
1757                certtypes,
1758                sigschemes,
1759                canames,
1760            })
1761        }
1762    }
1763}
1764
/// An extension carried in a TLS 1.3 CertificateRequest.
#[derive(Debug)]
pub(crate) enum CertReqExtension {
    /// Signature schemes; decoding rejects an empty list.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// Acceptable CA names, from the `CertificateAuthorities` extension.
    AuthorityNames(Vec<DistinguishedName>),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1771
1772impl CertReqExtension {
1773    pub(crate) fn get_type(&self) -> ExtensionType {
1774        match *self {
1775            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
1776            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
1777            Self::Unknown(ref r) => r.typ,
1778        }
1779    }
1780}
1781
1782impl Codec for CertReqExtension {
1783    fn encode(&self, bytes: &mut Vec<u8>) {
1784        self.get_type().encode(bytes);
1785
1786        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1787        match *self {
1788            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
1789            Self::AuthorityNames(ref r) => r.encode(nested.buf),
1790            Self::Unknown(ref r) => r.encode(nested.buf),
1791        }
1792    }
1793
1794    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1795        let typ = ExtensionType::read(r)?;
1796        let len = u16::read(r)? as usize;
1797        let mut sub = r.sub(len)?;
1798
1799        let ext = match typ {
1800            ExtensionType::SignatureAlgorithms => {
1801                let schemes = Vec::read(&mut sub)?;
1802                if schemes.is_empty() {
1803                    return Err(InvalidMessage::NoSignatureSchemes);
1804                }
1805                Self::SignatureAlgorithms(schemes)
1806            }
1807            ExtensionType::CertificateAuthorities => {
1808                let cas = Vec::read(&mut sub)?;
1809                Self::AuthorityNames(cas)
1810            }
1811            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1812        };
1813
1814        sub.expect_empty("CertReqExtension")
1815            .map(|_| ext)
1816    }
1817}
1818
/// CertificateRequest extension lists use a u16 length prefix.
impl TlsListElement for CertReqExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1822
/// A TLS 1.3 CertificateRequest: a request context followed by extensions.
#[derive(Debug)]
pub struct CertificateRequestPayloadTls13 {
    pub(crate) context: PayloadU8,
    pub(crate) extensions: Vec<CertReqExtension>,
}
1828
1829impl Codec for CertificateRequestPayloadTls13 {
1830    fn encode(&self, bytes: &mut Vec<u8>) {
1831        self.context.encode(bytes);
1832        self.extensions.encode(bytes);
1833    }
1834
1835    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1836        let context = PayloadU8::read(r)?;
1837        let extensions = Vec::read(r)?;
1838
1839        Ok(Self {
1840            context,
1841            extensions,
1842        })
1843    }
1844}
1845
1846impl CertificateRequestPayloadTls13 {
1847    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
1848        self.extensions
1849            .iter()
1850            .find(|x| x.get_type() == ext)
1851    }
1852
1853    pub(crate) fn get_sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1854        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1855        match *ext {
1856            CertReqExtension::SignatureAlgorithms(ref sa) => Some(sa),
1857            _ => None,
1858        }
1859    }
1860
1861    pub(crate) fn get_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1862        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
1863        match *ext {
1864            CertReqExtension::AuthorityNames(ref an) => Some(an),
1865            _ => None,
1866        }
1867    }
1868}
1869
// -- NewSessionTicket --
/// A NewSessionTicket body: a lifetime hint plus the ticket bytes.
#[derive(Debug)]
pub struct NewSessionTicketPayload {
    pub(crate) lifetime_hint: u32,
    pub(crate) ticket: PayloadU16,
}
1876
1877impl NewSessionTicketPayload {
1878    #[cfg(feature = "tls12")]
1879    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
1880        Self {
1881            lifetime_hint,
1882            ticket: PayloadU16::new(ticket),
1883        }
1884    }
1885}
1886
1887impl Codec for NewSessionTicketPayload {
1888    fn encode(&self, bytes: &mut Vec<u8>) {
1889        self.lifetime_hint.encode(bytes);
1890        self.ticket.encode(bytes);
1891    }
1892
1893    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1894        let lifetime = u32::read(r)?;
1895        let ticket = PayloadU16::read(r)?;
1896
1897        Ok(Self {
1898            lifetime_hint: lifetime,
1899            ticket,
1900        })
1901    }
1902}
1903
// -- NewSessionTicket (TLS 1.3) --
/// An extension carried in a TLS 1.3 NewSessionTicket message.
#[derive(Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// The `EarlyData` extension and its u32 value.
    EarlyData(u32),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1910
1911impl NewSessionTicketExtension {
1912    pub(crate) fn get_type(&self) -> ExtensionType {
1913        match *self {
1914            Self::EarlyData(_) => ExtensionType::EarlyData,
1915            Self::Unknown(ref r) => r.typ,
1916        }
1917    }
1918}
1919
1920impl Codec for NewSessionTicketExtension {
1921    fn encode(&self, bytes: &mut Vec<u8>) {
1922        self.get_type().encode(bytes);
1923
1924        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1925        match *self {
1926            Self::EarlyData(r) => r.encode(nested.buf),
1927            Self::Unknown(ref r) => r.encode(nested.buf),
1928        }
1929    }
1930
1931    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1932        let typ = ExtensionType::read(r)?;
1933        let len = u16::read(r)? as usize;
1934        let mut sub = r.sub(len)?;
1935
1936        let ext = match typ {
1937            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
1938            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1939        };
1940
1941        sub.expect_empty("NewSessionTicketExtension")
1942            .map(|_| ext)
1943    }
1944}
1945
1946impl TlsListElement for NewSessionTicketExtension {
1947    const SIZE_LEN: ListLength = ListLength::U16;
1948}
1949
1950#[derive(Debug)]
1951pub struct NewSessionTicketPayloadTls13 {
1952    pub(crate) lifetime: u32,
1953    pub(crate) age_add: u32,
1954    pub(crate) nonce: PayloadU8,
1955    pub(crate) ticket: PayloadU16,
1956    pub(crate) exts: Vec<NewSessionTicketExtension>,
1957}
1958
1959impl NewSessionTicketPayloadTls13 {
1960    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
1961        Self {
1962            lifetime,
1963            age_add,
1964            nonce: PayloadU8::new(nonce),
1965            ticket: PayloadU16::new(ticket),
1966            exts: vec![],
1967        }
1968    }
1969
1970    pub(crate) fn has_duplicate_extension(&self) -> bool {
1971        let mut seen = BTreeSet::new();
1972
1973        for ext in &self.exts {
1974            let typ = ext.get_type().get_u16();
1975
1976            if seen.contains(&typ) {
1977                return true;
1978            }
1979            seen.insert(typ);
1980        }
1981
1982        false
1983    }
1984
1985    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
1986        self.exts
1987            .iter()
1988            .find(|x| x.get_type() == ext)
1989    }
1990
1991    pub(crate) fn get_max_early_data_size(&self) -> Option<u32> {
1992        let ext = self.find_extension(ExtensionType::EarlyData)?;
1993        match *ext {
1994            NewSessionTicketExtension::EarlyData(ref sz) => Some(*sz),
1995            _ => None,
1996        }
1997    }
1998}
1999
2000impl Codec for NewSessionTicketPayloadTls13 {
2001    fn encode(&self, bytes: &mut Vec<u8>) {
2002        self.lifetime.encode(bytes);
2003        self.age_add.encode(bytes);
2004        self.nonce.encode(bytes);
2005        self.ticket.encode(bytes);
2006        self.exts.encode(bytes);
2007    }
2008
2009    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2010        let lifetime = u32::read(r)?;
2011        let age_add = u32::read(r)?;
2012        let nonce = PayloadU8::read(r)?;
2013        let ticket = PayloadU16::read(r)?;
2014        let exts = Vec::read(r)?;
2015
2016        Ok(Self {
2017            lifetime,
2018            age_add,
2019            nonce,
2020            ticket,
2021            exts,
2022        })
2023    }
2024}
2025
2026// -- RFC6066 certificate status types
2027
2028/// Only supports OCSP
2029#[derive(Debug)]
2030pub struct CertificateStatus {
2031    pub(crate) ocsp_response: PayloadU24,
2032}
2033
2034impl Codec for CertificateStatus {
2035    fn encode(&self, bytes: &mut Vec<u8>) {
2036        CertificateStatusType::OCSP.encode(bytes);
2037        self.ocsp_response.encode(bytes);
2038    }
2039
2040    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2041        let typ = CertificateStatusType::read(r)?;
2042
2043        match typ {
2044            CertificateStatusType::OCSP => Ok(Self {
2045                ocsp_response: PayloadU24::read(r)?,
2046            }),
2047            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2048        }
2049    }
2050}
2051
2052impl CertificateStatus {
2053    pub(crate) fn new(ocsp: Vec<u8>) -> Self {
2054        Self {
2055            ocsp_response: PayloadU24::new(ocsp),
2056        }
2057    }
2058
2059    #[cfg(feature = "tls12")]
2060    pub(crate) fn into_inner(self) -> Vec<u8> {
2061        self.ocsp_response.0
2062    }
2063}
2064
2065#[derive(Debug)]
2066pub enum HandshakePayload {
2067    HelloRequest,
2068    ClientHello(ClientHelloPayload),
2069    ServerHello(ServerHelloPayload),
2070    HelloRetryRequest(HelloRetryRequest),
2071    Certificate(CertificateChain),
2072    CertificateTls13(CertificatePayloadTls13),
2073    ServerKeyExchange(ServerKeyExchangePayload),
2074    CertificateRequest(CertificateRequestPayload),
2075    CertificateRequestTls13(CertificateRequestPayloadTls13),
2076    CertificateVerify(DigitallySignedStruct),
2077    ServerHelloDone,
2078    EndOfEarlyData,
2079    ClientKeyExchange(Payload),
2080    NewSessionTicket(NewSessionTicketPayload),
2081    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
2082    EncryptedExtensions(Vec<ServerExtension>),
2083    KeyUpdate(KeyUpdateRequest),
2084    Finished(Payload),
2085    CertificateStatus(CertificateStatus),
2086    MessageHash(Payload),
2087    Unknown(Payload),
2088}
2089
2090impl HandshakePayload {
2091    fn encode(&self, bytes: &mut Vec<u8>) {
2092        use self::HandshakePayload::*;
2093        match *self {
2094            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2095            ClientHello(ref x) => x.encode(bytes),
2096            ServerHello(ref x) => x.encode(bytes),
2097            HelloRetryRequest(ref x) => x.encode(bytes),
2098            Certificate(ref x) => x.encode(bytes),
2099            CertificateTls13(ref x) => x.encode(bytes),
2100            ServerKeyExchange(ref x) => x.encode(bytes),
2101            ClientKeyExchange(ref x) => x.encode(bytes),
2102            CertificateRequest(ref x) => x.encode(bytes),
2103            CertificateRequestTls13(ref x) => x.encode(bytes),
2104            CertificateVerify(ref x) => x.encode(bytes),
2105            NewSessionTicket(ref x) => x.encode(bytes),
2106            NewSessionTicketTls13(ref x) => x.encode(bytes),
2107            EncryptedExtensions(ref x) => x.encode(bytes),
2108            KeyUpdate(ref x) => x.encode(bytes),
2109            Finished(ref x) => x.encode(bytes),
2110            CertificateStatus(ref x) => x.encode(bytes),
2111            MessageHash(ref x) => x.encode(bytes),
2112            Unknown(ref x) => x.encode(bytes),
2113        }
2114    }
2115}
2116
2117#[derive(Debug)]
2118pub struct HandshakeMessagePayload {
2119    pub typ: HandshakeType,
2120    pub payload: HandshakePayload,
2121}
2122
2123impl Codec for HandshakeMessagePayload {
2124    fn encode(&self, bytes: &mut Vec<u8>) {
2125        // output type, length, and encoded payload
2126        match self.typ {
2127            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2128            _ => self.typ,
2129        }
2130        .encode(bytes);
2131
2132        let nested = LengthPrefixedBuffer::new(ListLength::U24 { max: usize::MAX }, bytes);
2133        self.payload.encode(nested.buf);
2134    }
2135
2136    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2137        Self::read_version(r, ProtocolVersion::TLSv1_2)
2138    }
2139}
2140
impl HandshakeMessagePayload {
    /// Decodes one handshake message: a `HandshakeType`, a 24-bit body length,
    /// and a body of exactly that many bytes.
    ///
    /// `vers` selects between layouts that differ by protocol version. When it
    /// is `ProtocolVersion::TLSv1_3`, `Certificate`, `CertificateRequest` and
    /// `NewSessionTicket` are decoded with their TLS 1.3 payload types.
    /// Otherwise they use the non-TLS 1.3 payload types.
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - the header, body length, or payload fails to decode;
    /// - the message type is `MessageHash` or `HelloRetryRequest`, neither of
    ///   which is valid on the wire (`InvalidMessage::UnexpectedMessage`);
    /// - bytes remain in the body after the payload has been decoded.
    pub(crate) fn read_version(
        r: &mut Reader,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        // `typ` is mutable because a ServerHello carrying the HelloRetryRequest
        // random is re-labelled as `HandshakeType::HelloRetryRequest` below.
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // Confine all payload parsing to the declared body length.
        let mut sub = r.sub(len)?;

        // Arm order matters here:
        // - each version-guarded arm must come before its unguarded fallback;
        // - a HelloRequest with a non-empty body fails its guard and falls
        //   through to the catch-all `Unknown` arm.
        let payload = match typ {
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // ServerHello and HelloRetryRequest share a wire type. Read the
                // version and random first, then compare the random with the
                // HelloRetryRequest sentinel to decide how to parse the rest.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Kept as opaque bytes. NOTE(review): `Payload::read` is used
                // without `?` and presumably takes the remainder of `sub`.
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // Synthetic type (see `build_handshake_hash`); never appears on the wire.
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // Not a legal wire type: HelloRetryRequest is sent as a ServerHello.
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        // The payload parser must have consumed the entire declared body.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }

    /// Builds a `KeyUpdate` message carrying
    /// `KeyUpdateRequest::UpdateNotRequested`.
    pub(crate) fn build_key_update_notify() -> Self {
        Self {
            typ: HandshakeType::KeyUpdate,
            payload: HandshakePayload::KeyUpdate(KeyUpdateRequest::UpdateNotRequested),
        }
    }

    /// Returns this message's encoding with the trailing PSK binders removed.
    ///
    /// If the payload is a `ClientHello` whose *last* extension is
    /// `PresharedKey`, the encoded length of that extension's `binders` list is
    /// cut from the end of the full encoding. In every other case the full
    /// encoding is returned unchanged.
    ///
    /// This produces the partial ClientHello over which PSK binders are
    /// computed (RFC 8446, section 4.2.11.2). It relies on the binders list
    /// being the final bytes of the encoded message.
    pub(crate) fn get_encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();

        let binder_len = match self.payload {
            HandshakePayload::ClientHello(ref ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(ref offer)) => {
                    // Re-encode just the binders to learn their on-wire length.
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        };

        let ret_len = ret.len() - binder_len;
        ret.truncate(ret_len);
        ret
    }

    /// Builds a synthetic `MessageHash` message whose payload is a copy of
    /// `hash`.
    ///
    /// `read_version` rejects this type, so such a message is never parsed from
    /// the wire.
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self {
            typ: HandshakeType::MessageHash,
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
        }
    }
}
2273
2274#[derive(Clone, Debug, Default, Eq, PartialEq)]
2275pub struct HpkeSymmetricCipherSuite {
2276    pub kdf_id: HpkeKdf,
2277    pub aead_id: HpkeAead,
2278}
2279
2280impl Codec for HpkeSymmetricCipherSuite {
2281    fn encode(&self, bytes: &mut Vec<u8>) {
2282        self.kdf_id.encode(bytes);
2283        self.aead_id.encode(bytes);
2284    }
2285
2286    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2287        Ok(Self {
2288            kdf_id: HpkeKdf::read(r)?,
2289            aead_id: HpkeAead::read(r)?,
2290        })
2291    }
2292}
2293
2294impl TlsListElement for HpkeSymmetricCipherSuite {
2295    const SIZE_LEN: ListLength = ListLength::U16;
2296}
2297
2298#[derive(Clone, Debug)]
2299pub struct HpkeKeyConfig {
2300    pub config_id: u8,
2301    pub kem_id: HpkeKem,
2302    pub public_key: PayloadU16,
2303    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2304}
2305
2306impl Codec for HpkeKeyConfig {
2307    fn encode(&self, bytes: &mut Vec<u8>) {
2308        self.config_id.encode(bytes);
2309        self.kem_id.encode(bytes);
2310        self.public_key.encode(bytes);
2311        self.symmetric_cipher_suites
2312            .encode(bytes);
2313    }
2314
2315    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2316        Ok(Self {
2317            config_id: u8::read(r)?,
2318            kem_id: HpkeKem::read(r)?,
2319            public_key: PayloadU16::read(r)?,
2320            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2321        })
2322    }
2323}
2324
2325#[derive(Clone, Debug)]
2326pub struct EchConfigContents {
2327    pub key_config: HpkeKeyConfig,
2328    pub maximum_name_length: u8,
2329    pub public_name: DnsName<'static>,
2330    pub extensions: PayloadU16,
2331}
2332
2333impl Codec for EchConfigContents {
2334    fn encode(&self, bytes: &mut Vec<u8>) {
2335        self.key_config.encode(bytes);
2336        self.maximum_name_length.encode(bytes);
2337        let dns_name = &self.public_name.borrow();
2338        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2339        self.extensions.encode(bytes);
2340    }
2341
2342    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2343        Ok(Self {
2344            key_config: HpkeKeyConfig::read(r)?,
2345            maximum_name_length: u8::read(r)?,
2346            public_name: {
2347                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2348                    .map_err(|_| InvalidMessage::InvalidServerName)?
2349                    .to_owned()
2350            },
2351            extensions: PayloadU16::read(r)?,
2352        })
2353    }
2354}
2355
2356#[derive(Clone, Debug)]
2357pub struct EchConfig {
2358    pub version: EchVersion,
2359    pub contents: EchConfigContents,
2360}
2361
2362impl Codec for EchConfig {
2363    fn encode(&self, bytes: &mut Vec<u8>) {
2364        self.version.encode(bytes);
2365        let mut contents = Vec::with_capacity(128);
2366        self.contents.encode(&mut contents);
2367        let length: &mut [u8; 2] = &mut [0, 0];
2368        codec::put_u16(contents.len() as u16, length);
2369        bytes.extend_from_slice(length);
2370        bytes.extend(contents);
2371    }
2372
2373    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2374        let version = EchVersion::read(r)?;
2375        let length = u16::read(r)?;
2376        Ok(Self {
2377            version,
2378            contents: EchConfigContents::read(&mut r.sub(length as usize)?)?,
2379        })
2380    }
2381}
2382
2383impl TlsListElement for EchConfig {
2384    const SIZE_LEN: ListLength = ListLength::U16;
2385}