1use crate::client::{ClientConfig, ClientConnectionData};
3use crate::common_state::{CommonState, Protocol, Side};
4use crate::conn::{ConnectionCore, SideData};
5use crate::crypto::cipher::{AeadKey, Iv};
6use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock};
7use crate::enums::{AlertDescription, ProtocolVersion};
8use crate::error::Error;
9use crate::msgs::deframer::DeframerVecBuffer;
10use crate::msgs::handshake::{ClientExtension, ServerExtension};
11use crate::server::{ServerConfig, ServerConnectionData};
12use crate::tls13::key_schedule::{
13 hkdf_expand_label, hkdf_expand_label_aead_key, hkdf_expand_label_block,
14};
15use crate::tls13::Tls13CipherSuite;
16
17use pki_types::ServerName;
18
19use alloc::boxed::Box;
20use alloc::collections::VecDeque;
21use alloc::sync::Arc;
22use alloc::vec;
23use alloc::vec::Vec;
24use core::fmt::{self, Debug};
25use core::ops::{Deref, DerefMut};
26
/// A QUIC connection, either client-side or server-side.
///
/// Dereferences to the [`CommonState`] of whichever side it wraps.
#[derive(Debug)]
pub enum Connection {
    /// A client-side QUIC connection.
    Client(ClientConnection),
    /// A server-side QUIC connection.
    Server(ServerConnection),
}
35
36impl Connection {
37 pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
41 match self {
42 Self::Client(conn) => conn.quic_transport_parameters(),
43 Self::Server(conn) => conn.quic_transport_parameters(),
44 }
45 }
46
47 pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
49 match self {
50 Self::Client(conn) => conn.zero_rtt_keys(),
51 Self::Server(conn) => conn.zero_rtt_keys(),
52 }
53 }
54
55 pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
59 match self {
60 Self::Client(conn) => conn.read_hs(plaintext),
61 Self::Server(conn) => conn.read_hs(plaintext),
62 }
63 }
64
65 pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
69 match self {
70 Self::Client(conn) => conn.write_hs(buf),
71 Self::Server(conn) => conn.write_hs(buf),
72 }
73 }
74
75 pub fn alert(&self) -> Option<AlertDescription> {
79 match self {
80 Self::Client(conn) => conn.alert(),
81 Self::Server(conn) => conn.alert(),
82 }
83 }
84
85 #[inline]
101 pub fn export_keying_material<T: AsMut<[u8]>>(
102 &self,
103 output: T,
104 label: &[u8],
105 context: Option<&[u8]>,
106 ) -> Result<T, Error> {
107 match self {
108 Self::Client(conn) => conn
109 .core
110 .export_keying_material(output, label, context),
111 Self::Server(conn) => conn
112 .core
113 .export_keying_material(output, label, context),
114 }
115 }
116}
117
118impl Deref for Connection {
119 type Target = CommonState;
120
121 fn deref(&self) -> &Self::Target {
122 match self {
123 Self::Client(conn) => &conn.core.common_state,
124 Self::Server(conn) => &conn.core.common_state,
125 }
126 }
127}
128
129impl DerefMut for Connection {
130 fn deref_mut(&mut self) -> &mut Self::Target {
131 match self {
132 Self::Client(conn) => &mut conn.core.common_state,
133 Self::Server(conn) => &mut conn.core.common_state,
134 }
135 }
136}
137
/// A QUIC client connection.
///
/// Wraps the shared [`ConnectionCommon`] state and dereferences to it.
pub struct ClientConnection {
    // Shared QUIC/TLS connection state for the client side.
    inner: ConnectionCommon<ClientConnectionData>,
}
142
143impl ClientConnection {
144 pub fn new(
149 config: Arc<ClientConfig>,
150 quic_version: Version,
151 name: ServerName<'static>,
152 params: Vec<u8>,
153 ) -> Result<Self, Error> {
154 if !config.supports_version(ProtocolVersion::TLSv1_3) {
155 return Err(Error::General(
156 "TLS 1.3 support is required for QUIC".into(),
157 ));
158 }
159
160 if !config.supports_protocol(Protocol::Quic) {
161 return Err(Error::General(
162 "at least one ciphersuite must support QUIC".into(),
163 ));
164 }
165
166 let ext = match quic_version {
167 Version::V1Draft => ClientExtension::TransportParametersDraft(params),
168 Version::V1 | Version::V2 => ClientExtension::TransportParameters(params),
169 };
170
171 let mut inner = ConnectionCore::for_client(config, name, vec![ext], Protocol::Quic)?;
172 inner.common_state.quic.version = quic_version;
173 Ok(Self {
174 inner: inner.into(),
175 })
176 }
177
178 pub fn is_early_data_accepted(&self) -> bool {
184 self.inner.core.is_early_data_accepted()
185 }
186}
187
188impl Deref for ClientConnection {
189 type Target = ConnectionCommon<ClientConnectionData>;
190
191 fn deref(&self) -> &Self::Target {
192 &self.inner
193 }
194}
195
196impl DerefMut for ClientConnection {
197 fn deref_mut(&mut self) -> &mut Self::Target {
198 &mut self.inner
199 }
200}
201
202impl Debug for ClientConnection {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 f.debug_struct("quic::ClientConnection")
205 .finish()
206 }
207}
208
209impl From<ClientConnection> for Connection {
210 fn from(c: ClientConnection) -> Self {
211 Self::Client(c)
212 }
213}
214
/// A QUIC server connection.
///
/// Wraps the shared [`ConnectionCommon`] state and dereferences to it.
pub struct ServerConnection {
    // Shared QUIC/TLS connection state for the server side.
    inner: ConnectionCommon<ServerConnectionData>,
}
219
220impl ServerConnection {
221 pub fn new(
226 config: Arc<ServerConfig>,
227 quic_version: Version,
228 params: Vec<u8>,
229 ) -> Result<Self, Error> {
230 if !config.supports_version(ProtocolVersion::TLSv1_3) {
231 return Err(Error::General(
232 "TLS 1.3 support is required for QUIC".into(),
233 ));
234 }
235
236 if !config.supports_protocol(Protocol::Quic) {
237 return Err(Error::General(
238 "at least one ciphersuite must support QUIC".into(),
239 ));
240 }
241
242 if config.max_early_data_size != 0 && config.max_early_data_size != 0xffff_ffff {
243 return Err(Error::General(
244 "QUIC sessions must set a max early data of 0 or 2^32-1".into(),
245 ));
246 }
247
248 let ext = match quic_version {
249 Version::V1Draft => ServerExtension::TransportParametersDraft(params),
250 Version::V1 | Version::V2 => ServerExtension::TransportParameters(params),
251 };
252
253 let mut core = ConnectionCore::for_server(config, vec![ext])?;
254 core.common_state.protocol = Protocol::Quic;
255 core.common_state.quic.version = quic_version;
256 Ok(Self { inner: core.into() })
257 }
258
259 pub fn reject_early_data(&mut self) {
265 self.inner.core.reject_early_data()
266 }
267
268 pub fn server_name(&self) -> Option<&str> {
284 self.inner.core.get_sni_str()
285 }
286}
287
288impl Deref for ServerConnection {
289 type Target = ConnectionCommon<ServerConnectionData>;
290
291 fn deref(&self) -> &Self::Target {
292 &self.inner
293 }
294}
295
296impl DerefMut for ServerConnection {
297 fn deref_mut(&mut self) -> &mut Self::Target {
298 &mut self.inner
299 }
300}
301
302impl Debug for ServerConnection {
303 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
304 f.debug_struct("quic::ServerConnection")
305 .finish()
306 }
307}
308
309impl From<ServerConnection> for Connection {
310 fn from(c: ServerConnection) -> Self {
311 Self::Server(c)
312 }
313}
314
/// State shared by QUIC client and server connections.
///
/// Dereferences to the TLS [`CommonState`].
pub struct ConnectionCommon<Data> {
    // The underlying TLS connection core.
    core: ConnectionCore<Data>,
    // Buffer that `read_hs` pushes incoming handshake bytes into before processing.
    deframer_buffer: DeframerVecBuffer,
}
320
321impl<Data: SideData> ConnectionCommon<Data> {
322 pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
330 self.core
331 .common_state
332 .quic
333 .params
334 .as_ref()
335 .map(|v| v.as_ref())
336 }
337
338 pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
340 let suite = self
341 .core
342 .common_state
343 .suite
344 .and_then(|suite| suite.tls13())?;
345 Some(DirectionalKeys::new(
346 suite,
347 suite.quic?,
348 self.core
349 .common_state
350 .quic
351 .early_secret
352 .as_ref()?,
353 self.core.common_state.quic.version,
354 ))
355 }
356
357 pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
361 self.core.message_deframer.push(
362 ProtocolVersion::TLSv1_3,
363 plaintext,
364 &mut self.deframer_buffer,
365 )?;
366 self.core
367 .process_new_packets(&mut self.deframer_buffer)?;
368 Ok(())
369 }
370
371 pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
375 self.core
376 .common_state
377 .quic
378 .write_hs(buf)
379 }
380
381 pub fn alert(&self) -> Option<AlertDescription> {
385 self.core.common_state.quic.alert
386 }
387}
388
389impl<Data> Deref for ConnectionCommon<Data> {
390 type Target = CommonState;
391
392 fn deref(&self) -> &Self::Target {
393 &self.core.common_state
394 }
395}
396
397impl<Data> DerefMut for ConnectionCommon<Data> {
398 fn deref_mut(&mut self) -> &mut Self::Target {
399 &mut self.core.common_state
400 }
401}
402
403impl<Data> From<ConnectionCore<Data>> for ConnectionCommon<Data> {
404 fn from(core: ConnectionCore<Data>) -> Self {
405 Self {
406 core,
407 deframer_buffer: DeframerVecBuffer::default(),
408 }
409 }
410}
411
/// QUIC-specific state stored alongside the TLS common state.
#[derive(Default)]
pub(crate) struct Quic {
    // Transport parameters returned by `quic_transport_parameters`.
    // NOTE(review): presumably the peer's parameters — confirm where this is set.
    pub(crate) params: Option<Vec<u8>>,
    // Alert reported to the QUIC layer via `alert()`.
    pub(crate) alert: Option<AlertDescription>,
    // Outgoing handshake messages drained by `write_hs`. When handshake secrets are
    // pending, `write_hs` stops draining before an entry whose flag is `true`.
    // NOTE(review): the flag presumably marks messages that must be sent under the
    // next keys — confirm with the producer side.
    pub(crate) hs_queue: VecDeque<(bool, Vec<u8>)>,
    // Early secret from which `zero_rtt_keys` derives 0-RTT keys.
    pub(crate) early_secret: Option<OkmBlock>,
    // Handshake secrets, consumed by `write_hs` to produce `KeyChange::Handshake`.
    pub(crate) hs_secrets: Option<Secrets>,
    // Traffic secrets, consumed by `write_hs` to produce `KeyChange::OneRtt`.
    pub(crate) traffic_secrets: Option<Secrets>,
    // Set once `write_hs` has returned `KeyChange::OneRtt`, so it is returned at most once.
    pub(crate) returned_traffic_keys: bool,
    // QUIC version; selects salts and HKDF labels.
    pub(crate) version: Version,
}
425
impl Quic {
    /// Drains queued outgoing handshake messages into `buf` and reports any key change.
    ///
    /// Messages are appended in queue order. If handshake secrets are pending and the
    /// next queued message is flagged `true`, draining stops early. The pending
    /// handshake keys are then returned before that message is written.
    ///
    /// Returns:
    /// - `Some(KeyChange::Handshake)` if handshake secrets were pending (they are consumed);
    /// - otherwise `Some(KeyChange::OneRtt)` the first time traffic secrets are present,
    ///   with `next` holding those secrets already advanced by one key update;
    /// - otherwise `None`.
    pub(crate) fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
        while let Some((_, msg)) = self.hs_queue.pop_front() {
            buf.extend_from_slice(&msg);
            // Stop before a `true`-flagged message while handshake secrets are still
            // pending, presumably so the caller can switch keys before writing it.
            if let Some(&(true, _)) = self.hs_queue.front() {
                if self.hs_secrets.is_some() {
                    break;
                }
            }
        }

        // Handshake keys take priority over 1-RTT keys.
        if let Some(secrets) = self.hs_secrets.take() {
            return Some(KeyChange::Handshake {
                keys: Keys::new(&secrets),
            });
        }

        // Note: the traffic secrets are taken (and dropped) even when they have
        // already been returned once.
        if let Some(mut secrets) = self.traffic_secrets.take() {
            if !self.returned_traffic_keys {
                self.returned_traffic_keys = true;
                // Keys for the current secrets; then advance `secrets` so the caller
                // receives the next generation in `next`.
                let keys = Keys::new(&secrets);
                secrets.update();
                return Some(KeyChange::OneRtt {
                    keys,
                    next: secrets,
                });
            }
        }

        None
    }
}
459
/// A pair of client/server secrets from which QUIC packet keys are derived.
///
/// The secrets can be advanced in place with the version's key-update label.
#[derive(Clone)]
pub struct Secrets {
    /// Secret for the client's direction.
    pub(crate) client: OkmBlock,
    /// Secret for the server's direction.
    pub(crate) server: OkmBlock,
    // Cipher suite whose HKDF provider is used for all derivations.
    suite: &'static Tls13CipherSuite,
    // QUIC algorithm used to build the concrete key objects.
    quic: &'static dyn Algorithm,
    // This endpoint's side; decides which secret is "local" and which is "remote".
    side: Side,
    // QUIC version; selects the HKDF labels.
    version: Version,
}
473
474impl Secrets {
475 pub(crate) fn new(
476 client: OkmBlock,
477 server: OkmBlock,
478 suite: &'static Tls13CipherSuite,
479 quic: &'static dyn Algorithm,
480 side: Side,
481 version: Version,
482 ) -> Self {
483 Self {
484 client,
485 server,
486 suite,
487 quic,
488 side,
489 version,
490 }
491 }
492
493 pub fn next_packet_keys(&mut self) -> PacketKeySet {
495 let keys = PacketKeySet::new(self);
496 self.update();
497 keys
498 }
499
500 pub(crate) fn update(&mut self) {
501 self.client = hkdf_expand_label_block(
502 self.suite
503 .hkdf_provider
504 .expander_for_okm(&self.client)
505 .as_ref(),
506 self.version.key_update_label(),
507 &[],
508 );
509 self.server = hkdf_expand_label_block(
510 self.suite
511 .hkdf_provider
512 .expander_for_okm(&self.server)
513 .as_ref(),
514 self.version.key_update_label(),
515 &[],
516 );
517 }
518
519 fn local_remote(&self) -> (&OkmBlock, &OkmBlock) {
520 match self.side {
521 Side::Client => (&self.client, &self.server),
522 Side::Server => (&self.server, &self.client),
523 }
524 }
525}
526
527pub struct DirectionalKeys {
529 pub header: Box<dyn HeaderProtectionKey>,
531 pub packet: Box<dyn PacketKey>,
533}
534
535impl DirectionalKeys {
536 pub(crate) fn new(
537 suite: &'static Tls13CipherSuite,
538 quic: &'static dyn Algorithm,
539 secret: &OkmBlock,
540 version: Version,
541 ) -> Self {
542 let builder = KeyBuilder::new(secret, version, quic, suite.hkdf_provider);
543 Self {
544 header: builder.header_protection_key(),
545 packet: builder.packet_key(),
546 }
547 }
548}
549
/// Length, in bytes, of a [`Tag`].
const TAG_LEN: usize = 16;

/// A fixed-size (16-byte) authentication tag, as returned by `PacketKey::encrypt_in_place`.
pub struct Tag([u8; TAG_LEN]);

impl From<&[u8]> for Tag {
    /// Copies `value` into a new tag.
    ///
    /// # Panics
    ///
    /// Panics if `value.len()` is not exactly `TAG_LEN` (16).
    fn from(value: &[u8]) -> Self {
        let mut tag = Self([0; TAG_LEN]);
        tag.0.copy_from_slice(value);
        tag
    }
}

impl AsRef<[u8]> for Tag {
    /// Borrows the tag bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
569
/// A QUIC packet-protection algorithm that builds concrete key objects from
/// derived key material.
pub trait Algorithm: Send + Sync {
    /// Builds a packet-protection key from an AEAD `key` and `iv`.
    ///
    /// In this module, `key` is derived with the version's packet-key label and `iv`
    /// with its packet-IV label.
    fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn PacketKey>;

    /// Builds a header-protection key from `key`.
    ///
    /// In this module, `key` is derived with the version's header-protection label.
    fn header_protection_key(&self, key: AeadKey) -> Box<dyn HeaderProtectionKey>;

    /// Length of the AEAD keys this algorithm expects (presumably in bytes).
    ///
    /// It is used as the output length when deriving both packet and
    /// header-protection keys.
    fn aead_key_len(&self) -> usize;
}
588
/// A key that applies or removes QUIC header protection.
pub trait HeaderProtectionKey: Send + Sync {
    /// Applies header protection in place.
    ///
    /// NOTE(review): presumably masks `first` (the first header byte) and
    /// `packet_number` using a mask derived from `sample`, as in RFC 9001 §5.4.
    /// Confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably returned for malformed inputs such as a
    /// `sample` of the wrong length.
    fn encrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;

    /// Removes header protection in place; the inverse of `encrypt_in_place`.
    ///
    /// NOTE(review): presumably `sample` must be taken from the same position as at
    /// encryption time — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably returned for malformed inputs.
    fn decrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;

    /// Length of the `sample` this key expects (presumably in bytes).
    fn sample_len(&self) -> usize;
}
649
/// A key that encrypts and decrypts QUIC packet payloads.
pub trait PacketKey: Send + Sync {
    /// Encrypts `payload` in place for packet `packet_number` and returns the
    /// authentication [`Tag`].
    ///
    /// NOTE(review): `header` is presumably authenticated as associated data —
    /// confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned if encryption fails.
    fn encrypt_in_place(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &mut [u8],
    ) -> Result<Tag, Error>;

    /// Decrypts `payload` in place for packet `packet_number`.
    ///
    /// The returned slice borrows from `payload`; presumably it is the plaintext with
    /// the trailing tag removed — confirm.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably returned when authentication fails.
    fn decrypt_in_place<'a>(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &'a mut [u8],
    ) -> Result<&'a [u8], Error>;

    /// Length of the authentication tag this key produces (presumably in bytes).
    fn tag_len(&self) -> usize;
}
683
684pub struct PacketKeySet {
686 pub local: Box<dyn PacketKey>,
688 pub remote: Box<dyn PacketKey>,
690}
691
692impl PacketKeySet {
693 fn new(secrets: &Secrets) -> Self {
694 let (local, remote) = secrets.local_remote();
695 let (version, alg, hkdf) = (secrets.version, secrets.quic, secrets.suite.hkdf_provider);
696 Self {
697 local: KeyBuilder::new(local, version, alg, hkdf).packet_key(),
698 remote: KeyBuilder::new(remote, version, alg, hkdf).packet_key(),
699 }
700 }
701}
702
703pub(crate) struct KeyBuilder<'a> {
704 expander: Box<dyn HkdfExpander>,
705 version: Version,
706 alg: &'a dyn Algorithm,
707}
708
709impl<'a> KeyBuilder<'a> {
710 pub(crate) fn new(
711 secret: &OkmBlock,
712 version: Version,
713 alg: &'a dyn Algorithm,
714 hkdf: &'a dyn Hkdf,
715 ) -> Self {
716 Self {
717 expander: hkdf.expander_for_okm(secret),
718 version,
719 alg,
720 }
721 }
722
723 pub(crate) fn packet_key(&self) -> Box<dyn PacketKey> {
725 let aead_key_len = self.alg.aead_key_len();
726 let packet_key = hkdf_expand_label_aead_key(
727 self.expander.as_ref(),
728 aead_key_len,
729 self.version.packet_key_label(),
730 &[],
731 );
732
733 let packet_iv =
734 hkdf_expand_label(self.expander.as_ref(), self.version.packet_iv_label(), &[]);
735 self.alg
736 .packet_key(packet_key, packet_iv)
737 }
738
739 pub(crate) fn header_protection_key(&self) -> Box<dyn HeaderProtectionKey> {
741 let header_key = hkdf_expand_label_aead_key(
742 self.expander.as_ref(),
743 self.alg.aead_key_len(),
744 self.version.header_key_label(),
745 &[],
746 );
747 self.alg
748 .header_protection_key(header_key)
749 }
750}
751
752pub struct Keys {
754 pub local: DirectionalKeys,
756 pub remote: DirectionalKeys,
758}
759
760impl Keys {
761 pub fn initial(
763 version: Version,
764 suite: &'static Tls13CipherSuite,
765 quic: &'static dyn Algorithm,
766 client_dst_connection_id: &[u8],
767 side: Side,
768 ) -> Self {
769 const CLIENT_LABEL: &[u8] = b"client in";
770 const SERVER_LABEL: &[u8] = b"server in";
771 let salt = version.initial_salt();
772 let hs_secret = suite
773 .hkdf_provider
774 .extract_from_secret(Some(salt), client_dst_connection_id);
775
776 let secrets = Secrets {
777 version,
778 client: hkdf_expand_label_block(hs_secret.as_ref(), CLIENT_LABEL, &[]),
779 server: hkdf_expand_label_block(hs_secret.as_ref(), SERVER_LABEL, &[]),
780 suite,
781 quic,
782 side,
783 };
784 Self::new(&secrets)
785 }
786
787 fn new(secrets: &Secrets) -> Self {
788 let (local, remote) = secrets.local_remote();
789 Self {
790 local: DirectionalKeys::new(secrets.suite, secrets.quic, local, secrets.version),
791 remote: DirectionalKeys::new(secrets.suite, secrets.quic, remote, secrets.version),
792 }
793 }
794}
795
/// A change of packet-protection keys reported by `write_hs`.
// `OneRtt` carries a `Secrets` value in addition to `Keys`, so it is larger than
// `Handshake`. NOTE(review): the lint is presumably silenced to keep the public
// variants unboxed — confirm.
#[allow(clippy::large_enum_variant)]
pub enum KeyChange {
    /// Keys derived from the handshake secrets.
    Handshake {
        /// Keys for both directions.
        keys: Keys,
    },
    /// Keys derived from the 1-RTT traffic secrets.
    OneRtt {
        /// Keys for both directions, derived from the current traffic secrets.
        keys: Keys,
        /// Traffic secrets already advanced by one key update; used to derive the
        /// next generation of packet keys.
        next: Secrets,
    },
}
824
/// The QUIC version a connection uses.
///
/// The version selects the initial salt, the HKDF labels for packet, IV,
/// header-protection and key-update derivations, and the transport-parameters
/// extension codepoint.
#[non_exhaustive]
// `PartialEq`/`Eq` let callers compare versions directly (Rust API Guidelines
// C-COMMON-TRAITS); adding these derives is backward compatible.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Version {
    /// Pre-RFC draft version; uses the draft transport-parameters extension.
    ///
    /// NOTE(review): its initial salt matches draft-ietf-quic-tls-29 — confirm which
    /// drafts this variant is meant to cover.
    V1Draft,
    /// QUIC version 1 (RFC 9000 / RFC 9001).
    V1,
    /// QUIC version 2 (RFC 9369).
    V2,
}

impl Version {
    /// Salt used when extracting the Initial secret from the client's destination
    /// connection ID (RFC 9001 §5.2 for v1; RFC 9369 §3.3.1 for v2).
    fn initial_salt(self) -> &'static [u8; 20] {
        match self {
            Self::V1Draft => &[
                0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
                0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
            ],
            Self::V1 => &[
                0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
                0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
            ],
            Self::V2 => &[
                0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26,
                0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
            ],
        }
    }

    /// HKDF label used to derive the packet-protection key.
    pub(crate) fn packet_key_label(&self) -> &'static [u8] {
        match self {
            Self::V1Draft | Self::V1 => b"quic key",
            Self::V2 => b"quicv2 key",
        }
    }

    /// HKDF label used to derive the packet-protection IV.
    pub(crate) fn packet_iv_label(&self) -> &'static [u8] {
        match self {
            Self::V1Draft | Self::V1 => b"quic iv",
            Self::V2 => b"quicv2 iv",
        }
    }

    /// HKDF label used to derive the header-protection key.
    pub(crate) fn header_key_label(&self) -> &'static [u8] {
        match self {
            Self::V1Draft | Self::V1 => b"quic hp",
            Self::V2 => b"quicv2 hp",
        }
    }

    /// HKDF label used to derive the next generation of secrets on key update.
    fn key_update_label(&self) -> &'static [u8] {
        match self {
            Self::V1Draft | Self::V1 => b"quic ku",
            Self::V2 => b"quicv2 ku",
        }
    }
}

impl Default for Version {
    /// Defaults to QUIC version 1.
    fn default() -> Self {
        Self::V1
    }
}
897
#[cfg(test)]
mod tests {
    use super::{HeaderProtectionKey, PacketKey};
    use std::prelude::v1::*;

    /// The boxed key trait objects must be shareable across threads.
    #[test]
    fn auto_traits() {
        fn is_send_sync<T: Send + Sync>() {}
        is_send_sync::<Box<dyn HeaderProtectionKey>>();
        is_send_sync::<Box<dyn PacketKey>>();
    }
}