1use crate::enums::{AlertDescription, ContentType, HandshakeType, ProtocolVersion};
2use crate::error::{Error, InvalidMessage, PeerMisbehaved};
3#[cfg(feature = "logging")]
4use crate::log::{debug, warn};
5use crate::msgs::alert::AlertMessagePayload;
6use crate::msgs::base::Payload;
7use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
8use crate::msgs::fragmenter::MessageFragmenter;
9use crate::msgs::handshake::CertificateChain;
10use crate::msgs::message::MessagePayload;
11use crate::msgs::message::{BorrowedPlainMessage, Message, OpaqueMessage, PlainMessage};
12use crate::quic;
13use crate::record_layer;
14use crate::suites::PartiallyExtractedSecrets;
15use crate::suites::SupportedCipherSuite;
16#[cfg(feature = "tls12")]
17use crate::tls12::ConnectionSecrets;
18use crate::vecbuf::ChunkVecBuffer;
19
20use alloc::boxed::Box;
21use alloc::vec::Vec;
22
23use pki_types::CertificateDer;
24
25pub struct CommonState {
27 pub(crate) negotiated_version: Option<ProtocolVersion>,
28 pub(crate) side: Side,
29 pub(crate) record_layer: record_layer::RecordLayer,
30 pub(crate) suite: Option<SupportedCipherSuite>,
31 pub(crate) alpn_protocol: Option<Vec<u8>>,
32 pub(crate) aligned_handshake: bool,
33 pub(crate) may_send_application_data: bool,
34 pub(crate) may_receive_application_data: bool,
35 pub(crate) early_traffic: bool,
36 sent_fatal_alert: bool,
37 pub(crate) has_received_close_notify: bool,
39 pub(crate) has_seen_eof: bool,
40 pub(crate) received_middlebox_ccs: u8,
41 pub(crate) peer_certificates: Option<CertificateChain>,
42 message_fragmenter: MessageFragmenter,
43 pub(crate) received_plaintext: ChunkVecBuffer,
44 sendable_plaintext: ChunkVecBuffer,
45 pub(crate) sendable_tls: ChunkVecBuffer,
46 queued_key_update_message: Option<Vec<u8>>,
47
48 pub(crate) protocol: Protocol,
50 pub(crate) quic: quic::Quic,
51 pub(crate) enable_secret_extraction: bool,
52}
53
54impl CommonState {
55 pub(crate) fn new(side: Side) -> Self {
56 Self {
57 negotiated_version: None,
58 side,
59 record_layer: record_layer::RecordLayer::new(),
60 suite: None,
61 alpn_protocol: None,
62 aligned_handshake: true,
63 may_send_application_data: false,
64 may_receive_application_data: false,
65 early_traffic: false,
66 sent_fatal_alert: false,
67 has_received_close_notify: false,
68 has_seen_eof: false,
69 received_middlebox_ccs: 0,
70 peer_certificates: None,
71 message_fragmenter: MessageFragmenter::default(),
72 received_plaintext: ChunkVecBuffer::new(Some(DEFAULT_RECEIVED_PLAINTEXT_LIMIT)),
73 sendable_plaintext: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)),
74 sendable_tls: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)),
75 queued_key_update_message: None,
76 protocol: Protocol::Tcp,
77 quic: quic::Quic::default(),
78 enable_secret_extraction: false,
79 }
80 }
81
82 pub fn wants_write(&self) -> bool {
86 !self.sendable_tls.is_empty()
87 }
88
89 pub fn is_handshaking(&self) -> bool {
97 !(self.may_send_application_data && self.may_receive_application_data)
98 }
99
100 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
116 self.peer_certificates.as_deref()
117 }
118
119 pub fn alpn_protocol(&self) -> Option<&[u8]> {
125 self.get_alpn_protocol()
126 }
127
128 pub fn negotiated_cipher_suite(&self) -> Option<SupportedCipherSuite> {
132 self.suite
133 }
134
135 pub fn protocol_version(&self) -> Option<ProtocolVersion> {
139 self.negotiated_version
140 }
141
142 pub(crate) fn is_tls13(&self) -> bool {
143 matches!(self.negotiated_version, Some(ProtocolVersion::TLSv1_3))
144 }
145
146 pub(crate) fn process_main_protocol<Data>(
147 &mut self,
148 msg: Message,
149 mut state: Box<dyn State<Data>>,
150 data: &mut Data,
151 ) -> Result<Box<dyn State<Data>>, Error> {
152 if self.may_receive_application_data && !self.is_tls13() {
155 let reject_ty = match self.side {
156 Side::Client => HandshakeType::HelloRequest,
157 Side::Server => HandshakeType::ClientHello,
158 };
159 if msg.is_handshake_type(reject_ty) {
160 self.send_warning_alert(AlertDescription::NoRenegotiation);
161 return Ok(state);
162 }
163 }
164
165 let mut cx = Context { common: self, data };
166 match state.handle(&mut cx, msg) {
167 Ok(next) => {
168 state = next;
169 Ok(state)
170 }
171 Err(e @ Error::InappropriateMessage { .. })
172 | Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
173 Err(self.send_fatal_alert(AlertDescription::UnexpectedMessage, e))
174 }
175 Err(e) => Err(e),
176 }
177 }
178
179 pub(crate) fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
185 self.perhaps_write_key_update();
186 self.send_plain(data, Limit::Yes)
187 }
188
189 pub(crate) fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
190 debug_assert!(self.early_traffic);
191 debug_assert!(self.record_layer.is_encrypting());
192
193 if data.is_empty() {
194 return 0;
196 }
197
198 self.send_appdata_encrypt(data, Limit::Yes)
199 }
200
201 pub(crate) fn check_aligned_handshake(&mut self) -> Result<(), Error> {
206 if !self.aligned_handshake {
207 Err(self.send_fatal_alert(
208 AlertDescription::UnexpectedMessage,
209 PeerMisbehaved::KeyEpochWithPendingFragment,
210 ))
211 } else {
212 Ok(())
213 }
214 }
215
216 pub(crate) fn send_msg_encrypt(&mut self, m: PlainMessage) {
219 let iter = self
220 .message_fragmenter
221 .fragment_message(&m);
222 for m in iter {
223 self.send_single_fragment(m);
224 }
225 }
226
227 fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
229 let len = match limit {
234 Limit::Yes => self
235 .sendable_tls
236 .apply_limit(payload.len()),
237 Limit::No => payload.len(),
238 };
239
240 let iter = self.message_fragmenter.fragment_slice(
241 ContentType::ApplicationData,
242 ProtocolVersion::TLSv1_2,
243 &payload[..len],
244 );
245 for m in iter {
246 self.send_single_fragment(m);
247 }
248
249 len
250 }
251
252 fn send_single_fragment(&mut self, m: BorrowedPlainMessage) {
253 if self
256 .record_layer
257 .wants_close_before_encrypt()
258 {
259 self.send_close_notify();
260 }
261
262 if self.record_layer.encrypt_exhausted() {
265 return;
266 }
267
268 let em = self.record_layer.encrypt_outgoing(m);
269 self.queue_tls_message(em);
270 }
271
272 fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
278 if !self.may_send_application_data {
279 let len = match limit {
282 Limit::Yes => self
283 .sendable_plaintext
284 .append_limited_copy(data),
285 Limit::No => self
286 .sendable_plaintext
287 .append(data.to_vec()),
288 };
289 return len;
290 }
291
292 debug_assert!(self.record_layer.is_encrypting());
293
294 if data.is_empty() {
295 return 0;
297 }
298
299 self.send_appdata_encrypt(data, limit)
300 }
301
302 pub(crate) fn start_outgoing_traffic(&mut self) {
303 self.may_send_application_data = true;
304 self.flush_plaintext();
305 }
306
307 pub(crate) fn start_traffic(&mut self) {
308 self.may_receive_application_data = true;
309 self.start_outgoing_traffic();
310 }
311
312 pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
356 self.sendable_plaintext.set_limit(limit);
357 self.sendable_tls.set_limit(limit);
358 }
359
360 fn flush_plaintext(&mut self) {
363 if !self.may_send_application_data {
364 return;
365 }
366
367 while let Some(buf) = self.sendable_plaintext.pop() {
368 self.send_plain(&buf, Limit::No);
369 }
370 }
371
372 fn queue_tls_message(&mut self, m: OpaqueMessage) {
374 self.sendable_tls.append(m.encode());
375 }
376
377 pub(crate) fn send_msg(&mut self, m: Message, must_encrypt: bool) {
379 {
380 if let Protocol::Quic = self.protocol {
381 if let MessagePayload::Alert(alert) = m.payload {
382 self.quic.alert = Some(alert.description);
383 } else {
384 debug_assert!(
385 matches!(m.payload, MessagePayload::Handshake { .. }),
386 "QUIC uses TLS for the cryptographic handshake only"
387 );
388 let mut bytes = Vec::new();
389 m.payload.encode(&mut bytes);
390 self.quic
391 .hs_queue
392 .push_back((must_encrypt, bytes));
393 }
394 return;
395 }
396 }
397 if !must_encrypt {
398 let msg = &m.into();
399 let iter = self
400 .message_fragmenter
401 .fragment_message(msg);
402 for m in iter {
403 self.queue_tls_message(m.to_unencrypted_opaque());
404 }
405 } else {
406 self.send_msg_encrypt(m.into());
407 }
408 }
409
410 pub(crate) fn take_received_plaintext(&mut self, bytes: Payload) {
411 self.received_plaintext.append(bytes.0);
412 }
413
414 #[cfg(feature = "tls12")]
415 pub(crate) fn start_encryption_tls12(&mut self, secrets: &ConnectionSecrets, side: Side) {
416 let (dec, enc) = secrets.make_cipher_pair(side);
417 self.record_layer
418 .prepare_message_encrypter(enc);
419 self.record_layer
420 .prepare_message_decrypter(dec);
421 }
422
423 pub(crate) fn missing_extension(&mut self, why: PeerMisbehaved) -> Error {
424 self.send_fatal_alert(AlertDescription::MissingExtension, why)
425 }
426
427 fn send_warning_alert(&mut self, desc: AlertDescription) {
428 warn!("Sending warning alert {:?}", desc);
429 self.send_warning_alert_no_log(desc);
430 }
431
432 pub(crate) fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
433 if let AlertLevel::Unknown(_) = alert.level {
435 return Err(self.send_fatal_alert(
436 AlertDescription::IllegalParameter,
437 Error::AlertReceived(alert.description),
438 ));
439 }
440
441 if self.may_receive_application_data && alert.description == AlertDescription::CloseNotify {
444 self.has_received_close_notify = true;
445 return Ok(());
446 }
447
448 let err = Error::AlertReceived(alert.description);
451 if alert.level == AlertLevel::Warning {
452 if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
453 return Err(self.send_fatal_alert(AlertDescription::DecodeError, err));
454 } else {
455 warn!("TLS alert warning received: {:#?}", alert);
456 return Ok(());
457 }
458 }
459
460 Err(err)
461 }
462
463 pub(crate) fn send_cert_verify_error_alert(&mut self, err: Error) -> Error {
464 self.send_fatal_alert(
465 match &err {
466 Error::InvalidCertificate(e) => e.clone().into(),
467 Error::PeerMisbehaved(_) => AlertDescription::IllegalParameter,
468 _ => AlertDescription::HandshakeFailure,
469 },
470 err,
471 )
472 }
473
474 pub(crate) fn send_fatal_alert(
475 &mut self,
476 desc: AlertDescription,
477 err: impl Into<Error>,
478 ) -> Error {
479 debug_assert!(!self.sent_fatal_alert);
480 let m = Message::build_alert(AlertLevel::Fatal, desc);
481 self.send_msg(m, self.record_layer.is_encrypting());
482 self.sent_fatal_alert = true;
483 err.into()
484 }
485
486 pub fn send_close_notify(&mut self) {
492 debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
493 self.send_warning_alert_no_log(AlertDescription::CloseNotify);
494 }
495
496 fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
497 let m = Message::build_alert(AlertLevel::Warning, desc);
498 self.send_msg(m, self.record_layer.is_encrypting());
499 }
500
501 pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
502 self.message_fragmenter
503 .set_max_fragment_size(new)
504 }
505
506 pub(crate) fn get_alpn_protocol(&self) -> Option<&[u8]> {
507 self.alpn_protocol
508 .as_ref()
509 .map(AsRef::as_ref)
510 }
511
512 pub fn wants_read(&self) -> bool {
522 self.received_plaintext.is_empty()
529 && !self.has_received_close_notify
530 && (self.may_send_application_data || self.sendable_tls.is_empty())
531 }
532
533 pub(crate) fn current_io_state(&self) -> IoState {
534 IoState {
535 tls_bytes_to_write: self.sendable_tls.len(),
536 plaintext_bytes_to_read: self.received_plaintext.len(),
537 peer_has_closed: self.has_received_close_notify,
538 }
539 }
540
541 pub(crate) fn is_quic(&self) -> bool {
542 self.protocol == Protocol::Quic
543 }
544
545 pub(crate) fn should_update_key(
546 &mut self,
547 key_update_request: &KeyUpdateRequest,
548 ) -> Result<bool, Error> {
549 match key_update_request {
550 KeyUpdateRequest::UpdateNotRequested => Ok(false),
551 KeyUpdateRequest::UpdateRequested => Ok(self.queued_key_update_message.is_none()),
552 _ => Err(self.send_fatal_alert(
553 AlertDescription::IllegalParameter,
554 InvalidMessage::InvalidKeyUpdate,
555 )),
556 }
557 }
558
559 pub(crate) fn enqueue_key_update_notification(&mut self) {
560 let message = PlainMessage::from(Message::build_key_update_notify());
561 self.queued_key_update_message = Some(
562 self.record_layer
563 .encrypt_outgoing(message.borrow())
564 .encode(),
565 );
566 }
567
568 pub(crate) fn perhaps_write_key_update(&mut self) {
569 if let Some(message) = self.queued_key_update_message.take() {
570 self.sendable_tls.append(message);
571 }
572 }
573}
574
/// Snapshot of a connection's I/O situation, taken after processing input.
#[derive(PartialEq, Eq, Debug)]
pub struct IoState {
    tls_bytes_to_write: usize,
    plaintext_bytes_to_read: usize,
    peer_has_closed: bool,
}

impl IoState {
    /// Number of encoded TLS bytes waiting to be written to the transport.
    pub fn tls_bytes_to_write(&self) -> usize {
        let Self {
            tls_bytes_to_write, ..
        } = *self;
        tls_bytes_to_write
    }

    /// Number of received plaintext bytes waiting to be read by the caller.
    pub fn plaintext_bytes_to_read(&self) -> usize {
        let Self {
            plaintext_bytes_to_read,
            ..
        } = *self;
        plaintext_bytes_to_read
    }

    /// Whether the peer has signalled that it closed the connection.
    pub fn peer_has_closed(&self) -> bool {
        let Self {
            peer_has_closed, ..
        } = *self;
        peer_has_closed
    }
}
613
614pub(crate) trait State<Data>: Send + Sync {
615 fn handle(
616 self: Box<Self>,
617 cx: &mut Context<'_, Data>,
618 message: Message,
619 ) -> Result<Box<dyn State<Data>>, Error>;
620
621 fn export_keying_material(
622 &self,
623 _output: &mut [u8],
624 _label: &[u8],
625 _context: Option<&[u8]>,
626 ) -> Result<(), Error> {
627 Err(Error::HandshakeNotComplete)
628 }
629
630 fn extract_secrets(&self) -> Result<PartiallyExtractedSecrets, Error> {
631 Err(Error::HandshakeNotComplete)
632 }
633
634 fn handle_decrypt_error(&self) {}
635}
636
637pub(crate) struct Context<'a, Data> {
638 pub(crate) common: &'a mut CommonState,
639 pub(crate) data: &'a mut Data,
640}
641
/// Identifies which end of a TLS connection some state belongs to.
///
/// Equality is total, so `Eq` is derived alongside `PartialEq`. This lets
/// `Side` be used wherever full equality is required.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// The end that initiates the TLS connection.
    Client,
    /// The end that accepts the TLS connection.
    Server,
}

impl Side {
    /// Returns the opposite end of the connection.
    pub(crate) fn peer(&self) -> Self {
        match self {
            Self::Client => Self::Server,
            Self::Server => Self::Client,
        }
    }
}
659
/// Transport that carries this connection's TLS messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Protocol {
    /// Messages are framed as TLS records.
    Tcp,
    /// Handshake messages and alerts are handed to QUIC instead of being
    /// framed as records.
    Quic,
}
665
/// Whether a send operation must respect the configured buffer limit.
enum Limit {
    /// Accept only as many bytes as the relevant buffer limit allows.
    Yes,
    /// Accept everything regardless of buffer limits.
    No,
}
670
/// Default cap on received plaintext awaiting the caller: 16 KiB.
const DEFAULT_RECEIVED_PLAINTEXT_LIMIT: usize = 1 << 14;
/// Default cap on each outgoing buffer: 64 KiB.
const DEFAULT_BUFFER_LIMIT: usize = 1 << 16;