diff --git a/CHANGELOG.md b/CHANGELOG.md index 83205cf..72c6aea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Unreleased + * Add PSK (Pre-Shared Key) cipher suite for DTLS 1.2 (RFC 4279, RFC 7925) + * `PSK_AES128_CCM_8` (0xC0A8) + * Add `Dtls::new_12_psk()` constructor for PSK-only sessions + * Add `PskResolver` trait and PSK config builder methods + * Fix client to handle optional ServerKeyExchange in PSK handshakes (RFC 4279 §2) + # 0.4.3 * Fix server auto-sensing DTLS version with fragmented ClientHello #87 diff --git a/Cargo.lock b/Cargo.lock index d21a0c0..6393b05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -250,6 +250,18 @@ dependencies = [ "shlex", ] +[[package]] +name = "ccm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae3c82e4355234767756212c570e29833699ab63e6ffd161887314cc5b43847" +dependencies = [ + "aead", + "cipher", + "ctr", + "subtle", +] + [[package]] name = "cexpr" version = "0.6.0" @@ -468,10 +480,12 @@ dependencies = [ name = "dimpl" version = "0.4.3" dependencies = [ + "aes", "aes-gcm", "arrayvec", "aws-lc-rs", "bytes", + "ccm", "chacha20", "chacha20poly1305", "der", diff --git a/Cargo.toml b/Cargo.toml index b792359..47a05b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2024" license = "MIT OR Apache-2.0" repository = "https://github.com/algesten/dimpl" readme = "README.md" -keywords = ["dtls", "tls", "webrtc"] +keywords = ["dtls", "tls", "webrtc", "psk"] categories = ["network-programming", "cryptography", "security"] # MSRV @@ -17,13 +17,14 @@ rust-version = "1.85.0" default = ["aws-lc-rs", "rcgen"] # Default crypto provider -aws-lc-rs = ["dep:aws-lc-rs", "_crypto-common"] +aws-lc-rs = ["dep:aws-lc-rs", "dep:ccm", "dep:aes", "_crypto-common"] # Pure Rust crypto provider rust-crypto = [ "dep:aes-gcm", "dep:chacha20poly1305", "dep:chacha20", "dep:p256", "dep:p384", "dep:x25519-dalek", "dep:sha2", "dep:hmac", "dep:hkdf", "dep:ecdsa", "dep:generic-array", "dep:rand_core", + "dep:ccm", "dep:aes", "_crypto-common" ] @@ -68,6 +69,8 @@ generic-array = { version = "0.14", optional = true } rand_core = { version = "0.6", optional = true } chacha20poly1305 = { version = "0.10", optional = true } chacha20 = { version = "0.9", optional = true } +ccm = { version = "0.5", default-features = false, optional = true } +aes = { version = "0.8", optional = true } x25519-dalek = { version = "2", optional = true, features = ["static_secrets"] } # certificate generation diff --git a/README.md b/README.md index 9441d4f..2177f17 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,9 @@ verification and SRTP key export yourself. ### Version selection -Three constructors control which DTLS version is used: -- [`Dtls::new_12`][new_12] — explicit DTLS 1.2 +Four constructors control which DTLS version is used: +- [`Dtls::new_12`][new_12] — explicit DTLS 1.2 (certificate‑based) +- [`Dtls::new_12_psk`][new_12_psk] — explicit DTLS 1.2 (PSK, no certificates) - [`Dtls::new_13`][new_13] — explicit DTLS 1.3 - [`Dtls::new_auto`][new_auto] — auto‑sense: the first incoming ClientHello determines the version (based on the @@ -34,6 +35,8 @@ Three constructors control which DTLS version is used: - `ECDHE_ECDSA_AES256_GCM_SHA384` - `ECDHE_ECDSA_AES128_GCM_SHA256` - `ECDHE_ECDSA_CHACHA20_POLY1305_SHA256` +- **PSK cipher suites (TLS 1.2 over DTLS)** + - `PSK_AES128_CCM_8` - **Cipher suites (TLS 1.3 over DTLS)** - `TLS_AES_128_GCM_SHA256` - `TLS_AES_256_GCM_SHA384` @@ -44,7 +47,6 @@ Three constructors control which DTLS version is used: - **DTLS‑SRTP**: Exports keying material for `SRTP_AEAD_AES_256_GCM`, `SRTP_AEAD_AES_128_GCM`, and `SRTP_AES128_CM_SHA1_80` ([RFC 5764], [RFC 7714]). - **Extended Master Secret** ([RFC 7627]) is negotiated and enforced (DTLS 1.2). -- Not supported: PSK cipher suites. ### Certificate model During the handshake the engine emits @@ -131,6 +133,37 @@ let dtls = mk_dtls_client(); let _ = example_event_loop(dtls); ``` +## Example (PSK client) + +```rust +use std::sync::Arc; +use std::time::Instant; + +use dimpl::{Config, Dtls, PskResolver}; + +struct MyPsk; + +impl PskResolver for MyPsk { + fn resolve(&self, identity: &[u8]) -> Option> { + if identity == b"device-01" { + Some(b"shared-secret-key".to_vec()) + } else { + None + } + } +} + +let config = Arc::new( + Config::builder() + .with_psk_client(b"device-01".to_vec(), Arc::new(MyPsk)) + .build() + .unwrap(), +); + +let mut dtls = Dtls::new_12_psk(config, Instant::now()); +dtls.set_active(true); // client role +``` + #### MSRV Rust 1.85.0 @@ -139,6 +172,7 @@ Rust 1.85.0 - Renegotiation is not implemented (WebRTC does full restart). [new_12]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_12 +[new_12_psk]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_12_psk [new_13]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_13 [new_auto]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_auto [peer_cert]: https://docs.rs/dimpl/latest/dimpl/enum.Output.html#variant.PeerCert diff --git a/src/auto.rs b/src/auto.rs index bfe6af2..52c41d3 100644 --- a/src/auto.rs +++ b/src/auto.rs @@ -105,7 +105,7 @@ impl HybridClientHello { ch_body.push(0); // cipher_suites: 1.3 suites first, then 1.2 suites (filtered by config) - let mut suites: ArrayVec = ArrayVec::new(); + let mut suites: ArrayVec = ArrayVec::new(); for cs in config.dtls13_cipher_suites() { suites.push(cs.suite().as_u16()); } diff --git a/src/config.rs b/src/config.rs index 138155d..f48d295 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,6 @@ +use std::fmt; +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::sync::Arc; use std::time::Duration; use crate::Error; @@ -6,6 +9,41 @@ use crate::crypto::{SupportedDtls13CipherSuite, SupportedKxGroup}; use crate::dtls12::message::Dtls12CipherSuite; use crate::types::{Dtls13CipherSuite, NamedGroup}; +/// Callback for resolving PSK identities to shared secrets. +/// +/// Implement this trait and provide it via [`ConfigBuilder::with_psk_client`] +/// or [`ConfigBuilder::with_psk_server`] to enable PSK cipher suites. +pub trait PskResolver: Send + Sync + UnwindSafe + RefUnwindSafe { + /// Look up a pre-shared key by the peer's identity. + /// + /// Returns the shared secret bytes, or `None` if the identity is unknown. + fn resolve(&self, identity: &[u8]) -> Option>; +} + +/// PSK configuration for a DTLS endpoint. +/// +/// Use [`Psk::Client`] for endpoints that initiate PSK handshakes (send identity), +/// and [`Psk::Server`] for endpoints that resolve incoming identities. +#[derive(Clone)] +pub enum Psk { + /// Client-side PSK: sends `identity` during handshake, uses `resolver` + /// to look up the shared secret. + Client { + /// The identity to send to the server. + identity: Vec, + /// Resolver for looking up shared secrets. + resolver: Arc, + }, + /// Server-side PSK: optionally sends a `hint` to help the client choose + /// an identity, uses `resolver` to look up secrets by client identity. + Server { + /// Optional hint sent to the client in ServerKeyExchange. + hint: Option>, + /// Resolver for looking up shared secrets. + resolver: Arc, + }, +} + #[cfg(feature = "aws-lc-rs")] use crate::crypto::aws_lc_rs; @@ -15,7 +53,7 @@ use crate::crypto::rust_crypto; /// DTLS configuration shared by all connections. /// /// Build with [`Config::builder()`] or use [`Config::default()`]. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Config { mtu: usize, max_queue_rx: usize, @@ -31,6 +69,7 @@ pub struct Config { dtls12_cipher_suites: Option>, dtls13_cipher_suites: Option>, kx_groups: Option>, + psk: Option, } impl Config { @@ -51,6 +90,7 @@ impl Config { dtls12_cipher_suites: None, dtls13_cipher_suites: None, kx_groups: None, + psk: None, } } @@ -148,21 +188,58 @@ impl Config { self.aead_encryption_limit } + /// PSK configuration, if any. + pub fn psk(&self) -> Option<&Psk> { + self.psk.as_ref() + } + + /// PSK identity for the client to send during handshake. + pub fn psk_identity(&self) -> Option<&[u8]> { + match &self.psk { + Some(Psk::Client { identity, .. }) => Some(identity), + _ => None, + } + } + + /// PSK identity hint for the server to send during handshake. + pub fn psk_identity_hint(&self) -> Option<&[u8]> { + match &self.psk { + Some(Psk::Server { hint, .. }) => hint.as_deref(), + _ => None, + } + } + + /// PSK resolver for looking up shared secrets by identity. + pub fn psk_resolver(&self) -> Option<&dyn PskResolver> { + match &self.psk { + Some(Psk::Client { resolver, .. } | Psk::Server { resolver, .. }) => { + Some(resolver.as_ref()) + } + None => None, + } + } + /// Allowed DTLS 1.2 cipher suites, filtered by the config's allow-list. /// /// Returns all provider-supported DTLS 1.2 cipher suites when no filter /// is set. When a filter is set via the builder's `dtls12_cipher_suites` /// method, only suites in both the provider and the filter are returned. + /// + /// PSK cipher suites are excluded when no [`PskResolver`] is configured, + /// preventing a certificate-mode endpoint from negotiating a PSK suite + /// and inadvertently skipping certificate authentication. pub fn dtls12_cipher_suites( &self, ) -> impl Iterator + '_ { let filter = self.dtls12_cipher_suites.as_ref(); + let has_psk = self.psk.is_some(); self.crypto_provider .supported_cipher_suites() .filter(move |cs| match filter { Some(list) => list.contains(&cs.suite()), None => true, }) + .filter(move |cs| has_psk || !cs.suite().is_psk()) } /// Allowed DTLS 1.3 cipher suites, filtered by the config's allow-list. @@ -201,7 +278,6 @@ impl Config { } /// Builder for [`Config`]. See each setter for defaults. -#[derive(Debug)] pub struct ConfigBuilder { mtu: usize, max_queue_rx: usize, @@ -217,6 +293,7 @@ pub struct ConfigBuilder { dtls12_cipher_suites: Option>, dtls13_cipher_suites: Option>, kx_groups: Option>, + psk: Option, } impl ConfigBuilder { @@ -360,6 +437,28 @@ impl ConfigBuilder { self } + /// Configure PSK for a client endpoint. + /// + /// The `identity` is sent to the server during the handshake. + /// The `resolver` looks up the shared secret by identity. + pub fn with_psk_client(mut self, identity: Vec, resolver: Arc) -> Self { + self.psk = Some(Psk::Client { identity, resolver }); + self + } + + /// Configure PSK for a server endpoint. + /// + /// The optional `hint` is sent to the client in ServerKeyExchange. + /// The `resolver` looks up the shared secret by client identity. + pub fn with_psk_server( + mut self, + hint: Option>, + resolver: Arc, + ) -> Self { + self.psk = Some(Psk::Server { hint, resolver }); + self + } + /// Build the configuration. /// /// This validates the crypto provider before returning the configuration. @@ -429,14 +528,28 @@ impl ConfigBuilder { )); } + // Check if we have any non-PSK DTLS 1.2 suites that need key exchange groups + let has_non_psk_dtls12 = { + match &self.dtls12_cipher_suites { + Some(list) => crypto_provider + .supported_cipher_suites() + .filter(|cs| list.contains(&cs.suite())) + .any(|cs| !cs.suite().is_psk()), + None => crypto_provider + .supported_cipher_suites() + .any(|cs| !cs.suite().is_psk()), + } + }; + // Validate kx_groups filter: each enabled version needs compatible groups + // (PSK-only DTLS 1.2 configs don't need key exchange groups) let filtered_kx = |kx: &&'static dyn SupportedKxGroup| -> bool { match &self.kx_groups { Some(list) => list.contains(&kx.name()), None => true, } }; - if dtls12_count > 0 { + if has_non_psk_dtls12 { let dtls12_kx_count = crypto_provider .supported_kx_groups() .filter(|kx| filtered_kx(kx)) @@ -478,6 +591,7 @@ impl ConfigBuilder { dtls12_cipher_suites: self.dtls12_cipher_suites, dtls13_cipher_suites: self.dtls13_cipher_suites, kx_groups: self.kx_groups, + psk: self.psk, }) } } @@ -490,6 +604,73 @@ impl Default for Config { } } +impl fmt::Debug for Psk { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Psk::Client { identity, .. } => f + .debug_struct("Psk::Client") + .field("identity", &identity) + .field("resolver", &"...") + .finish(), + Psk::Server { hint, .. } => f + .debug_struct("Psk::Server") + .field("hint", &hint) + .field("resolver", &"...") + .finish(), + } + } +} + +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Config") + .field("mtu", &self.mtu) + .field("max_queue_rx", &self.max_queue_rx) + .field("max_queue_tx", &self.max_queue_tx) + .field( + "require_client_certificate", + &self.require_client_certificate, + ) + .field("use_server_cookie", &self.use_server_cookie) + .field("flight_start_rto", &self.flight_start_rto) + .field("flight_retries", &self.flight_retries) + .field("handshake_timeout", &self.handshake_timeout) + .field("crypto_provider", &self.crypto_provider) + .field("rng_seed", &self.rng_seed) + .field("aead_encryption_limit", &self.aead_encryption_limit) + .field("dtls12_cipher_suites", &self.dtls12_cipher_suites) + .field("dtls13_cipher_suites", &self.dtls13_cipher_suites) + .field("kx_groups", &self.kx_groups) + .field("psk", &self.psk) + .finish() + } +} + +impl fmt::Debug for ConfigBuilder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConfigBuilder") + .field("mtu", &self.mtu) + .field("max_queue_rx", &self.max_queue_rx) + .field("max_queue_tx", &self.max_queue_tx) + .field( + "require_client_certificate", + &self.require_client_certificate, + ) + .field("use_server_cookie", &self.use_server_cookie) + .field("flight_start_rto", &self.flight_start_rto) + .field("flight_retries", &self.flight_retries) + .field("handshake_timeout", &self.handshake_timeout) + .field("crypto_provider", &self.crypto_provider) + .field("rng_seed", &self.rng_seed) + .field("aead_encryption_limit", &self.aead_encryption_limit) + .field("dtls12_cipher_suites", &self.dtls12_cipher_suites) + .field("dtls13_cipher_suites", &self.dtls13_cipher_suites) + .field("kx_groups", &self.kx_groups) + .field("psk", &self.psk) + .finish() + } +} + #[cfg(test)] mod tests { use super::*; @@ -666,11 +847,40 @@ mod tests { fn no_filter_returns_all() { let config = Config::default(); // Default provider should have at least 2 DTLS 1.2 and 2 DTLS 1.3 suites + // (PSK suites are excluded without a resolver, so only non-PSK count) assert!(config.dtls12_cipher_suites().count() >= 2); assert!(config.dtls13_cipher_suites().count() >= 2); assert!(config.kx_groups().count() >= 2); } + #[test] + fn psk_suites_excluded_without_resolver() { + let config = Config::default(); + assert!( + config.dtls12_cipher_suites().all(|cs| !cs.suite().is_psk()), + "PSK suites should be excluded when no PskResolver is configured" + ); + } + + #[test] + fn psk_suites_included_with_resolver() { + struct DummyResolver; + impl PskResolver for DummyResolver { + fn resolve(&self, _identity: &[u8]) -> Option> { + None + } + } + + let config = Config::builder() + .with_psk_server(None, Arc::new(DummyResolver)) + .build() + .expect("config with PSK resolver should build"); + assert!( + config.dtls12_cipher_suites().any(|cs| cs.suite().is_psk()), + "PSK suites should be included when a PskResolver is configured" + ); + } + #[test] fn filter_with_explicit_provider() { #[cfg(feature = "aws-lc-rs")] diff --git a/src/crypto/aws_lc_rs/cipher_suite.rs b/src/crypto/aws_lc_rs/cipher_suite.rs index 83308a7..3bc02c6 100644 --- a/src/crypto/aws_lc_rs/cipher_suite.rs +++ b/src/crypto/aws_lc_rs/cipher_suite.rs @@ -232,16 +232,50 @@ impl SupportedDtls12CipherSuite for ChaCha20Poly1305Sha256 { } } +/// TLS_PSK_WITH_AES_128_CCM_8 cipher suite. +#[derive(Debug)] +struct PskAes128Ccm8; + +impl SupportedDtls12CipherSuite for PskAes128Ccm8 { + fn suite(&self) -> Dtls12CipherSuite { + Dtls12CipherSuite::PSK_AES128_CCM_8 + } + + fn hash_algorithm(&self) -> HashAlgorithm { + HashAlgorithm::SHA256 + } + + fn key_lengths(&self) -> (usize, usize, usize) { + (0, 16, 4) // (mac_key_len, enc_key_len, fixed_iv_len) + } + + fn explicit_nonce_len(&self) -> usize { + 8 + } + + fn tag_len(&self) -> usize { + 8 + } + + fn create_cipher(&self, key: &[u8]) -> Result, String> { + Ok(Box::new(crate::crypto::ccm_cipher::AesCcm8Cipher::new( + key, + )?)) + } +} + /// Static instances of supported DTLS 1.2 cipher suites. static AES_128_GCM_SHA256: Aes128GcmSha256 = Aes128GcmSha256; static AES_256_GCM_SHA384: Aes256GcmSha384 = Aes256GcmSha384; static CHACHA20_POLY1305_SHA256: ChaCha20Poly1305Sha256 = ChaCha20Poly1305Sha256; +static PSK_AES_128_CCM_8: PskAes128Ccm8 = PskAes128Ccm8; /// All supported DTLS 1.2 cipher suites. pub(super) static ALL_CIPHER_SUITES: &[&dyn SupportedDtls12CipherSuite] = &[ &AES_128_GCM_SHA256, &AES_256_GCM_SHA384, &CHACHA20_POLY1305_SHA256, + &PSK_AES_128_CCM_8, ]; // ============================================================================ diff --git a/src/crypto/ccm_cipher.rs b/src/crypto/ccm_cipher.rs new file mode 100644 index 0000000..d5837ac --- /dev/null +++ b/src/crypto/ccm_cipher.rs @@ -0,0 +1,90 @@ +//! AES-128-CCM-8 cipher implementation using the RustCrypto `ccm` crate. +//! +//! Shared by both aws-lc-rs and rust-crypto backends since aws-lc-rs +//! does not expose CCM in its high-level API. + +use ccm::aead::AeadInPlace; +use ccm::aead::KeyInit; +use ccm::consts::{U8, U12}; + +use super::{Aad, Cipher, Nonce}; +use crate::buffer::{Buf, TmpBuf}; + +/// AES-128-CCM with 8-byte tag, 12-byte nonce. +type Aes128Ccm8 = ccm::Ccm; + +/// AES-128-CCM-8 cipher for TLS_PSK_WITH_AES_128_CCM_8. +pub struct AesCcm8Cipher { + cipher: Box, +} + +impl std::fmt::Debug for AesCcm8Cipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AesCcm8Cipher").finish_non_exhaustive() + } +} + +impl AesCcm8Cipher { + pub fn new(key: &[u8]) -> Result { + if key.len() != 16 { + return Err(format!("Invalid key size for AES-128-CCM-8: {}", key.len())); + } + let cipher = Aes128Ccm8::new_from_slice(key) + .map_err(|_| "Failed to create AES-128-CCM-8 cipher".to_string())?; + Ok(AesCcm8Cipher { + cipher: Box::new(cipher), + }) + } +} + +impl Cipher for AesCcm8Cipher { + fn encrypt(&mut self, plaintext: &mut Buf, aad: Aad, nonce: Nonce) -> Result<(), String> { + if nonce.len() != 12 { + return Err(format!( + "Invalid nonce length: expected 12, got {}", + nonce.len() + )); + } + + let ccm_nonce = ccm::aead::generic_array::GenericArray::from_slice(&nonce[..12]); + let tag = self + .cipher + .encrypt_in_place_detached(ccm_nonce, &aad[..], plaintext.as_mut()) + .map_err(|_| "AES-128-CCM-8 encryption failed".to_string())?; + + // Append the 8-byte tag + plaintext.extend_from_slice(&tag); + + Ok(()) + } + + fn decrypt(&mut self, ciphertext: &mut TmpBuf, aad: Aad, nonce: Nonce) -> Result<(), String> { + if ciphertext.len() < 8 { + return Err(format!("Ciphertext too short: {}", ciphertext.len())); + } + + if nonce.len() != 12 { + return Err(format!( + "Invalid nonce length: expected 12, got {}", + nonce.len() + )); + } + + let ccm_nonce = ccm::aead::generic_array::GenericArray::from_slice(&nonce[..12]); + + // Split off the 8-byte tag from the end + let data_len = ciphertext.len() - 8; + let mut tag_bytes = [0u8; 8]; + tag_bytes.copy_from_slice(&ciphertext.as_ref()[data_len..]); + let tag = ccm::aead::generic_array::GenericArray::from(tag_bytes); + + // Truncate to just the ciphertext (without tag) + ciphertext.truncate(data_len); + + self.cipher + .decrypt_in_place_detached(ccm_nonce, &aad[..], ciphertext.as_mut(), &tag) + .map_err(|_| "AES-128-CCM-8 decryption failed".to_string())?; + + Ok(()) + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9c53469..e2ed7c8 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -12,6 +12,9 @@ pub mod aws_lc_rs; #[cfg(feature = "rust-crypto")] pub mod rust_crypto; +#[cfg(any(feature = "aws-lc-rs", feature = "rust-crypto"))] +pub(crate) mod ccm_cipher; + mod dtls_aead; mod provider; mod validation; diff --git a/src/crypto/rust_crypto/cipher_suite.rs b/src/crypto/rust_crypto/cipher_suite.rs index b0520d9..dc4ab0d 100644 --- a/src/crypto/rust_crypto/cipher_suite.rs +++ b/src/crypto/rust_crypto/cipher_suite.rs @@ -282,16 +282,50 @@ impl SupportedDtls12CipherSuite for ChaCha20Poly1305Sha256 { } } +/// TLS_PSK_WITH_AES_128_CCM_8 cipher suite. +#[derive(Debug)] +struct PskAes128Ccm8; + +impl SupportedDtls12CipherSuite for PskAes128Ccm8 { + fn suite(&self) -> Dtls12CipherSuite { + Dtls12CipherSuite::PSK_AES128_CCM_8 + } + + fn hash_algorithm(&self) -> HashAlgorithm { + HashAlgorithm::SHA256 + } + + fn key_lengths(&self) -> (usize, usize, usize) { + (0, 16, 4) // (mac_key_len, enc_key_len, fixed_iv_len) + } + + fn explicit_nonce_len(&self) -> usize { + 8 + } + + fn tag_len(&self) -> usize { + 8 + } + + fn create_cipher(&self, key: &[u8]) -> Result, String> { + Ok(Box::new(crate::crypto::ccm_cipher::AesCcm8Cipher::new( + key, + )?)) + } +} + /// Static instances of supported DTLS 1.2 cipher suites. static AES_128_GCM_SHA256: Aes128GcmSha256 = Aes128GcmSha256; static AES_256_GCM_SHA384: Aes256GcmSha384 = Aes256GcmSha384; static CHACHA20_POLY1305_SHA256: ChaCha20Poly1305Sha256 = ChaCha20Poly1305Sha256; +static PSK_AES_128_CCM_8: PskAes128Ccm8 = PskAes128Ccm8; /// All supported DTLS 1.2 cipher suites. pub(super) static ALL_CIPHER_SUITES: &[&dyn SupportedDtls12CipherSuite] = &[ &AES_128_GCM_SHA256, &AES_256_GCM_SHA384, &CHACHA20_POLY1305_SHA256, + &PSK_AES_128_CCM_8, ]; // ============================================================================ diff --git a/src/crypto/validation/mod.rs b/src/crypto/validation/mod.rs index a32eca1..fa041ae 100644 --- a/src/crypto/validation/mod.rs +++ b/src/crypto/validation/mod.rs @@ -48,7 +48,7 @@ impl CryptoProvider { sig_alg: SignatureAlgorithm, ) -> impl Iterator { self.supported_cipher_suites() - .filter(move |cs| cs.suite().signature_algorithm() == sig_alg) + .filter(move |cs| cs.suite().signature_algorithm() == Some(sig_alg)) } /// Check if provider supports ECDH-based cipher suites. @@ -217,7 +217,11 @@ impl CryptoProvider { // Test signature verification for each supported cipher suite for cs in self.supported_cipher_suites() { let hash_alg = cs.suite().hash_algorithm(); - let sig_alg = cs.suite().signature_algorithm(); + let sig_alg = match cs.suite().signature_algorithm() { + Some(alg) => alg, + // PSK suites have no signature — skip validation + None => continue, + }; let (cert_der, signature, test_data) = match (hash_alg, sig_alg) { (HashAlgorithm::SHA256, SignatureAlgorithm::ECDSA) => ( @@ -692,7 +696,9 @@ mod tests_aws_lc_rs { fn test_default_provider_has_cipher_suites() { let provider = aws_lc_rs::default_provider(); let count = provider.supported_cipher_suites().count(); - assert_eq!(count, 3); // AES-128, AES-256, and ChaCha20-Poly1305 + // ECDHE: AES-128, AES-256, ChaCha20 + // PSK: CCM-8 + assert_eq!(count, 4); } #[test] @@ -740,7 +746,9 @@ mod tests_rust_crypto { fn test_default_provider_has_cipher_suites() { let provider = rust_crypto::default_provider(); let count = provider.supported_cipher_suites().count(); - assert_eq!(count, 3); // AES-128, AES-256, and ChaCha20-Poly1305 + // ECDHE: AES-128, AES-256, ChaCha20 + // PSK: CCM-8 + assert_eq!(count, 4); } #[test] diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 4bbf0af..fae997a 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -21,12 +21,17 @@ use subtle::ConstantTimeEq; use crate::buffer::{Buf, ToBuf}; use crate::crypto::SrtpProfile; use crate::dtls12::Server; +use crate::dtls12::context::AuthMode; use crate::dtls12::engine::Engine; -use crate::dtls12::message::{Body, CipherSuiteVec, ClientHello, ClientKeyExchange}; -use crate::dtls12::message::{CompressionMethod, ContentType, Cookie, Dtls12CipherSuite}; +use crate::dtls12::message::{ + Body, CipherSuiteVec, ClientHello, ClientKeyExchange, ClientPskKeys, ServerKeyExchangeParams, +}; +use crate::dtls12::message::{ + CompressionMethod, ContentType, Cookie, DigitallySigned, Dtls12CipherSuite, +}; use crate::dtls12::message::{ExtensionType, KeyExchangeAlgorithm, MessageType, ProtocolVersion}; use crate::dtls12::message::{Random, SessionId, SignatureAndHashAlgorithm, UseSrtpExtension}; -use crate::{Error, KeyingMaterial, Output}; +use crate::{Config, DtlsCertificate, Error, KeyingMaterial, Output}; /// DTLS client pub struct Client { @@ -121,11 +126,20 @@ impl Client { pub(crate) fn new_from_hybrid( random: Random, handshake_fragment: &[u8], - config: std::sync::Arc, - certificate: crate::DtlsCertificate, + config: std::sync::Arc, + certificate: DtlsCertificate, now: Instant, ) -> Result { - let mut engine = Engine::new(config, certificate); + let private_key = config + .crypto_provider() + .key_provider + .load_private_key(&certificate.private_key) + .expect("Failed to parse client private key"); + let auth = AuthMode::Certificate { + certificate: certificate.certificate, + private_key, + }; + let mut engine = Engine::new(config, auth); engine.set_client(true); // The hybrid ClientHello was sent with message_seq=0 outside this // engine. Advance the counter so the with-cookie CH gets message_seq=1 @@ -489,7 +503,12 @@ impl State { } trace!("Extended Master Secret enabled"); - Ok(Self::AwaitCertificate) + // PSK suites skip Certificate; go directly to ServerKeyExchange + if cs.is_psk() { + Ok(Self::AwaitServerKeyExchange) + } else { + Ok(Self::AwaitCertificate) + } } fn await_certificate(self, client: &mut Client) -> Result { @@ -537,6 +556,64 @@ impl State { } fn await_server_key_exchange(self, client: &mut Client) -> Result { + let cipher_suite = client + .engine + .cipher_suite() + .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; + + if cipher_suite.is_psk() { + self.await_server_key_exchange_psk(client) + } else { + self.await_server_key_exchange_ecdhe(client) + } + } + + /// PSK ServerKeyExchange carries only an optional identity hint (no signature). + /// Per RFC 4279 §2, ServerKeyExchange is omitted when the server has no hint. + fn await_server_key_exchange_psk(self, client: &mut Client) -> Result { + // If the server skipped ServerKeyExchange (no hint), go straight to ServerHelloDone + let has_done = client + .engine + .has_complete_handshake(MessageType::ServerHelloDone); + if has_done { + return Ok(Self::AwaitServerHelloDone); + } + + let maybe = client.engine.next_handshake( + MessageType::ServerKeyExchange, + &mut client.defragment_buffer, + )?; + + let Some(handshake) = maybe else { + return Ok(self); + }; + + let Body::ServerKeyExchange(ske) = &handshake.body else { + unreachable!() + }; + + // PSK ServerKeyExchange contains only an identity hint per RFC 4279 §2 + // (no curve_type or named_group — those are ECDHE-only parameters). + let hint_range = match &ske.params { + ServerKeyExchangeParams::Psk(psk) => psk.hint_range.clone(), + _ => { + return Err(Error::UnexpectedMessage( + "ECDHE ServerKeyExchange in PSK path".to_string(), + )); + } + }; + + drop(handshake); + + let hint = &client.defragment_buffer[hint_range]; + trace!("PSK identity hint ({} bytes)", hint.len()); + // Hint is informational only; we don't use it for PSK lookup currently + + // PSK has no CertificateRequest + Ok(Self::AwaitServerHelloDone) + } + + fn await_server_key_exchange_ecdhe(self, client: &mut Client) -> Result { let maybe = client.engine.next_handshake( MessageType::ServerKeyExchange, &mut client.defragment_buffer, @@ -566,11 +643,16 @@ impl State { // Extract ECDH params ranges let (curve_type, named_group, public_key_range) = match &server_key_exchange.params { - crate::dtls12::message::ServerKeyExchangeParams::Ecdh(ecdh) => ( + ServerKeyExchangeParams::Ecdh(ecdh) => ( ecdh.curve_type, ecdh.named_group, ecdh.public_key_range.clone(), ), + ServerKeyExchangeParams::Psk(_) => { + return Err(Error::UnexpectedMessage( + "PSK ServerKeyExchange in ECDHE path".to_string(), + )); + } }; ( @@ -617,19 +699,20 @@ impl State { } // Ensure the signature algorithm is compatible with the cipher suite - if signature_algorithm.signature != cipher_suite.signature_algorithm() { - return Err(Error::CryptoError(format!( - "Signature algorithm mismatch: {:?} != {:?}", - signature_algorithm.signature, - cipher_suite.signature_algorithm() - ))); + if let Some(expected_sig) = cipher_suite.signature_algorithm() { + if signature_algorithm.signature != expected_sig { + return Err(Error::CryptoError(format!( + "Signature algorithm mismatch: {:?} != {:?}", + signature_algorithm.signature, expected_sig + ))); + } } // unwrap: is ok because we verify the order of the flight let cert_der = client.server_certificates.first().unwrap(); // Create a temporary DigitallySigned for verification (we only need the algorithm) - let temp_signed = crate::dtls12::message::DigitallySigned { + let temp_signed = DigitallySigned { algorithm: signature_algorithm, signature_range: 0..signature_bytes.len(), }; @@ -690,10 +773,12 @@ impl State { // Check that the hash algorithm that is default fo the PrivateKey in use // is one of the supported by the CertificateRequest + // unwrap: CertificateRequest only received for certificate-based suites let hash_algorithm = client .engine .crypto_context() - .private_key_default_hash_algorithm(); + .private_key_default_hash_algorithm() + .unwrap(); if !cr.supports_hash_algorithm(hash_algorithm) { return Err(Error::CertificateError(format!( @@ -729,6 +814,16 @@ impl State { trace!("Received ServerHelloDone"); + let cipher_suite = client + .engine + .cipher_suite() + .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; + + if cipher_suite.is_psk() { + // PSK: no certificates involved + return Ok(Self::SendClientKeyExchange); + } + // Validate the server certificate if client.server_certificates.is_empty() { return Err(Error::CertificateError( @@ -1130,24 +1225,11 @@ fn handshake_create_client_key_exchange(body: &mut Buf, engine: &mut Engine) -> debug!("Using key exchange algorithm: {:?}", key_exchange_algorithm); - // For ECDHE, get group info before we create the handshake (to avoid borrow issues) - let group_info = if key_exchange_algorithm == KeyExchangeAlgorithm::EECDH { - engine.crypto_context().get_key_exchange_group_info() - } else { - None - }; - - // Generate key exchange data - let public_key = engine - .crypto_context_mut() - .maybe_init_key_exchange() - .map_err(|e| Error::CryptoError(format!("Failed to generate key exchange: {}", e)))?; - - trace!("Generated public key size: {} bytes", public_key.len()); - - // Validate key exchange algorithm match key_exchange_algorithm { KeyExchangeAlgorithm::EECDH => { + // Get group info before the mutable borrow + let group_info = engine.crypto_context().get_key_exchange_group_info(); + // For ECDHE, use the group information we retrieved earlier let Some((curve_type, named_group)) = group_info else { unreachable!("No group info available for ECDHE"); @@ -1157,6 +1239,40 @@ fn handshake_create_client_key_exchange(body: &mut Buf, engine: &mut Engine) -> "Using ECDHE group info: {:?}, {:?}", curve_type, named_group ); + + let public_key = engine + .crypto_context_mut() + .maybe_init_key_exchange() + .map_err(|e| { + Error::CryptoError(format!("Failed to generate key exchange: {}", e)) + })?; + + trace!("Generated public key size: {} bytes", public_key.len()); + ClientKeyExchange::serialize_from_bytes(public_key, body); + } + KeyExchangeAlgorithm::PSK => { + let identity = engine + .config() + .psk_identity() + .ok_or_else(|| Error::PskError("No PSK identity configured".to_string()))? + .to_vec(); + + // Resolve the PSK via the configured resolver + let psk = engine + .config() + .psk_resolver() + .ok_or_else(|| Error::PskError("No PSK resolver configured".to_string()))? + .resolve(&identity) + .ok_or_else(|| Error::PskError("PSK resolver returned no key".to_string()))?; + + // Set the PSK and compute pre-master secret + let crypto = engine.crypto_context_mut(); + crypto.set_psk(psk); + crypto + .compute_psk_pre_master_secret() + .map_err(|e| Error::CryptoError(format!("Failed to compute PSK PMS: {}", e)))?; + + ClientPskKeys::serialize_from_bytes(&identity, body); } _ => { return Err(Error::SecurityError( @@ -1165,9 +1281,6 @@ fn handshake_create_client_key_exchange(body: &mut Buf, engine: &mut Engine) -> } } - // Serialize the public key directly - ClientKeyExchange::serialize_from_bytes(public_key, body); - Ok(()) } @@ -1177,11 +1290,16 @@ fn handshake_create_certificate_verify(body: &mut Buf, engine: &mut Engine) -> R // if we negotiate ECDHE_ECDSA_AES256_GCM_SHA384, we are gogin to use // SHA384 for the signature of the main crypto, but not for CertificateVerify // where a private key using P256 curve means we use SHA256. - let hash_alg = engine.crypto_context().private_key_default_hash_algorithm(); + // unwrap: CertificateVerify only sent for certificate-based suites + let hash_alg = engine + .crypto_context() + .private_key_default_hash_algorithm() + .unwrap(); debug!("Using hash algorithm for signature: {:?}", hash_alg); // Get the signature algorithm type - let sig_alg = engine.crypto_context().signature_algorithm(); + // unwrap: CertificateVerify only sent for certificate-based suites + let sig_alg = engine.crypto_context().signature_algorithm().unwrap(); debug!("Using signature algorithm: {:?}", sig_alg); // Create the signature algorithm diff --git a/src/dtls12/context.rs b/src/dtls12/context.rs index 58887b8..fe4e08c 100644 --- a/src/dtls12/context.rs +++ b/src/dtls12/context.rs @@ -9,8 +9,24 @@ use crate::crypto; use crate::crypto::SrtpProfile; use crate::crypto::{Aad, Iv, Nonce}; use crate::dtls12::message::DigitallySigned; -use crate::dtls12::message::{Asn1Cert, Certificate, CurveType}; -use crate::dtls12::message::{Dtls12CipherSuite, HashAlgorithm, NamedGroup, SignatureAlgorithm}; +use crate::dtls12::message::{Asn1Cert, Certificate}; +use crate::dtls12::message::{ + CurveType, Dtls12CipherSuite, HashAlgorithm, NamedGroup, SignatureAlgorithm, +}; + +/// Authentication mode for a DTLS 1.2 session. +pub enum AuthMode { + /// Certificate-based authentication (ECDHE_ECDSA suites). + Certificate { + /// DER-encoded certificate. + certificate: Vec, + /// Parsed signing key for the certificate. + private_key: Box, + }, + /// Pre-shared key authentication (PSK suites). + /// The actual PSK value is resolved during the handshake via [`CryptoContext::set_psk`]. + Psk, +} /// DTLS 1.2 crypto context holding negotiated keys and ciphers for a session. pub struct CryptoContext { @@ -56,11 +72,11 @@ pub struct CryptoContext { /// Server cipher server_cipher: Option>, - /// Certificate (DER format) - certificate: Vec, + /// Authentication mode: certificate or PSK. + auth: AuthMode, - /// Parsed private key for the certificate with signature algorithm - private_key: Box, + /// Resolved PSK value (set during handshake after identity exchange) + psk: Option>, /// Client random (needed for SRTP key export per RFC 5705) client_random: Option>, @@ -70,28 +86,8 @@ pub struct CryptoContext { } impl CryptoContext { - /// Create a new crypto context - pub fn new( - certificate: Vec, - private_key_bytes: Vec, - config: Arc, - ) -> Self { - // Validate that we have a certificate and private key - if certificate.is_empty() { - panic!("Client certificate cannot be empty"); - } - - if private_key_bytes.is_empty() { - panic!("Client private key cannot be empty"); - } - - // Parse the private key using the provider - let private_key = config - .crypto_provider() - .key_provider - .load_private_key(&private_key_bytes) - .expect("Failed to parse client private key"); - + /// Create a new crypto context with the given authentication mode. + pub fn new(auth: AuthMode, config: Arc) -> Self { CryptoContext { config, key_exchange: None, @@ -107,8 +103,8 @@ impl CryptoContext { pre_master_secret: None, client_cipher: None, server_cipher: None, - certificate, - private_key, + auth, + psk: None, client_random: None, server_random: None, } @@ -154,6 +150,28 @@ impl CryptoContext { Ok(()) } + /// Set the resolved PSK value for this session. + pub fn set_psk(&mut self, psk: Vec) { + self.psk = Some(psk); + } + + /// Compute PSK pre-master secret per RFC 4279 §2. + /// + /// Format: `uint16(N) || zeros(N) || uint16(N) || PSK(N)` + /// where N is the PSK length. + pub fn compute_psk_pre_master_secret(&mut self) -> Result<(), String> { + let psk = self.psk.as_ref().ok_or("PSK not set")?; + let n = psk.len(); + // Total: 2 + N + 2 + N = 2N + 4 + let mut pms = Buf::new(); + pms.extend_from_slice(&(n as u16).to_be_bytes()); + pms.extend_from_slice(&vec![0u8; n]); + pms.extend_from_slice(&(n as u16).to_be_bytes()); + pms.extend_from_slice(psk); + self.pre_master_secret = Some(pms); + Ok(()) + } + /// Initialize ECDHE key exchange (server role) and return our ephemeral public key pub fn init_ecdh_server( &mut self, @@ -370,31 +388,41 @@ impl CryptoContext { } } - /// Get client certificate for authentication + /// Get client certificate for authentication. + /// Panics if no certificate is configured (PSK-only mode). pub fn get_client_certificate(&self) -> Certificate { - // We validate in constructor, so we can assume we have a certificate - // Create an Asn1Cert with a range covering the entire certificate - let cert = Asn1Cert(0..self.certificate.len()); + // unwrap: only called for certificate-based suites + let AuthMode::Certificate { certificate, .. } = &self.auth else { + panic!("get_client_certificate called in PSK mode"); + }; + let cert = Asn1Cert(0..certificate.len()); let mut certs = ArrayVec::new(); certs.push(cert); Certificate::new(certs) } - /// Serialize client certificate for authentication + /// Serialize client certificate for authentication. + /// Panics if no certificate is configured (PSK-only mode). pub fn serialize_client_certificate(&self, output: &mut Buf) { let cert = self.get_client_certificate(); - cert.serialize(&self.certificate, output); + let AuthMode::Certificate { certificate, .. } = &self.auth else { + panic!("serialize_client_certificate called in PSK mode"); + }; + cert.serialize(certificate, output); } - /// Sign the provided data using the client's private key - /// Returns the signature or an error if signing fails + /// Sign the provided data using the client's private key. + /// Returns an error if no private key is configured (PSK-only mode). pub fn sign_data( &mut self, data: &[u8], _hash_alg: HashAlgorithm, out: &mut Buf, ) -> Result<(), String> { - self.private_key.sign(data, out) + let AuthMode::Certificate { private_key, .. } = &mut self.auth else { + return Err("No private key configured (PSK mode)".to_string()); + }; + private_key.sign(data, out) } /// Generate verify data for a Finished message using PRF @@ -485,7 +513,30 @@ impl CryptoContext { Ok(keying_material) } - /// Get group info for ECDHE key exchange + /// Signature algorithm for the configured private key. + /// Returns None in PSK-only mode. + pub fn signature_algorithm(&self) -> Option { + match &self.auth { + AuthMode::Certificate { private_key, .. } => Some(private_key.algorithm()), + AuthMode::Psk => None, + } + } + + /// Default hash algorithm for the configured private key. + /// Returns None in PSK-only mode. + pub fn private_key_default_hash_algorithm(&self) -> Option { + match &self.auth { + AuthMode::Certificate { private_key, .. } => Some(private_key.hash_algorithm()), + AuthMode::Psk => None, + } + } + + /// Create a hash context for the given algorithm + pub fn create_hash(&self, algorithm: HashAlgorithm) -> Box { + self.provider().hash_provider.create_hash(algorithm) + } + + /// Get the key exchange group info (curve type and named group). pub fn get_key_exchange_group_info(&self) -> Option<(CurveType, NamedGroup)> { // Use stored group if available (after key exchange is consumed) if let Some(group) = self.key_exchange_group { @@ -499,24 +550,18 @@ impl CryptoContext { Some((CurveType::NamedCurve, ke.group())) } - /// Signature algorithm for the configured private key - pub fn signature_algorithm(&self) -> SignatureAlgorithm { - self.private_key.algorithm() - } - - /// Default hash algorithm for the configured private key - pub fn private_key_default_hash_algorithm(&self) -> HashAlgorithm { - self.private_key.hash_algorithm() - } - - /// Create a hash context for the given algorithm - pub fn create_hash(&self, algorithm: HashAlgorithm) -> Box { - self.provider().hash_provider.create_hash(algorithm) - } - /// Check if the client's private key is compatible with a given cipher suite. pub fn is_cipher_suite_compatible(&self, cipher_suite: Dtls12CipherSuite) -> bool { - cipher_suite.signature_algorithm() == self.private_key.algorithm() + match (&self.auth, cipher_suite.signature_algorithm()) { + // Certificate-based suite needs a matching private key + (AuthMode::Certificate { private_key, .. }, Some(sig_alg)) => { + sig_alg == private_key.algorithm() + } + // PSK suite is only compatible in PSK mode + (AuthMode::Psk, None) => true, + // Mismatch: cert context + PSK suite, or PSK context + cert suite + _ => false, + } } /// Get the client write IV if derived. @@ -546,3 +591,88 @@ impl CryptoContext { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::Config; + + #[cfg(feature = "rcgen")] + fn cert_auth_mode(config: &Config) -> AuthMode { + let cert = crate::certificate::generate_self_signed_certificate().expect("generate cert"); + let private_key = config + .crypto_provider() + .key_provider + .load_private_key(&cert.private_key) + .expect("parse key"); + AuthMode::Certificate { + certificate: cert.certificate, + private_key, + } + } + + #[test] + #[cfg(feature = "rcgen")] + fn certificate_mode_rejects_psk_suites() { + let config = Arc::new(Config::default()); + let auth = cert_auth_mode(&config); + let ctx = CryptoContext::new(auth, config); + + for suite in Dtls12CipherSuite::supported() { + if suite.is_psk() { + assert!( + !ctx.is_cipher_suite_compatible(*suite), + "Certificate-mode context must reject PSK suite {:?}", + suite + ); + } + } + } + + #[test] + #[cfg(feature = "rcgen")] + fn certificate_mode_accepts_ecdhe_suites() { + let config = Arc::new(Config::default()); + let auth = cert_auth_mode(&config); + let ctx = CryptoContext::new(auth, config); + + // At least one ECDHE_ECDSA suite should be compatible + assert!( + Dtls12CipherSuite::supported() + .iter() + .filter(|s| !s.is_psk()) + .any(|s| ctx.is_cipher_suite_compatible(*s)), + "Certificate-mode context must accept at least one ECDHE suite" + ); + } + + #[test] + fn psk_mode_rejects_certificate_suites() { + let config = Arc::new(Config::default()); + let ctx = CryptoContext::new(AuthMode::Psk, config); + + for suite in Dtls12CipherSuite::supported() { + if !suite.is_psk() { + assert!( + !ctx.is_cipher_suite_compatible(*suite), + "PSK-mode context must reject certificate suite {:?}", + suite + ); + } + } + } + + #[test] + fn psk_mode_accepts_psk_suites() { + let config = Arc::new(Config::default()); + let ctx = CryptoContext::new(AuthMode::Psk, config); + + assert!( + Dtls12CipherSuite::supported() + .iter() + .filter(|s| s.is_psk()) + .any(|s| ctx.is_cipher_suite_compatible(*s)), + "PSK-mode context must accept at least one PSK suite" + ); + } +} diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 69e310f..c335042 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -6,7 +6,7 @@ use std::time::{Duration, Instant}; use super::queue::{QueueRx, QueueTx}; use crate::buffer::{Buf, BufferPool, TmpBuf}; use crate::crypto::{Aad, Iv, Nonce}; -use crate::dtls12::context::CryptoContext; +use crate::dtls12::context::{AuthMode, CryptoContext}; use crate::dtls12::incoming::{Incoming, Record, RecordDecrypt}; use crate::dtls12::message::{Body, HashAlgorithm, Header, MessageType, ProtocolVersion, Sequence}; use crate::dtls12::message::{ContentType, DTLSRecord, Dtls12CipherSuite, Handshake}; @@ -105,17 +105,13 @@ struct Entry { } impl Engine { - pub fn new(config: Arc, certificate: crate::DtlsCertificate) -> Self { + pub fn new(config: Arc, auth: AuthMode) -> Self { let mut rng = SeededRng::new(config.rng_seed()); let flight_backoff = ExponentialBackoff::new(config.flight_start_rto(), config.flight_retries(), &mut rng); - let crypto_context = CryptoContext::new( - certificate.certificate, - certificate.private_key, - Arc::clone(&config), - ); + let crypto_context = CryptoContext::new(auth, Arc::clone(&config)); Self { config, diff --git a/src/dtls12/message/client_key_exchange.rs b/src/dtls12/message/client_key_exchange.rs index 43c5932..38c666a 100644 --- a/src/dtls12/message/client_key_exchange.rs +++ b/src/dtls12/message/client_key_exchange.rs @@ -15,6 +15,7 @@ pub struct ClientKeyExchange { #[derive(Debug, PartialEq, Eq)] pub enum ExchangeKeys { Ecdh(ClientEcdhKeys), + Psk(ClientPskKeys), } /// ECDHE key exchange parameters @@ -72,6 +73,10 @@ impl ClientKeyExchange { let (input, ecdh_keys) = ClientEcdhKeys::parse(input, base_offset)?; (input, ExchangeKeys::Ecdh(ecdh_keys)) } + KeyExchangeAlgorithm::PSK => { + let (input, psk_keys) = ClientPskKeys::parse(input, base_offset)?; + (input, ExchangeKeys::Psk(psk_keys)) + } _ => return Err(Err::Failure(Error::new(input, nom::error::ErrorKind::Tag))), }; @@ -81,6 +86,7 @@ impl ClientKeyExchange { pub fn serialize(&self, buf: &[u8], output: &mut Buf) { match &self.exchange_keys { ExchangeKeys::Ecdh(ecdh_keys) => ecdh_keys.serialize(buf, output), + ExchangeKeys::Psk(psk_keys) => psk_keys.serialize(buf, output), } } @@ -91,6 +97,49 @@ impl ClientKeyExchange { } } +/// PSK identity sent by the client (RFC 4279 §2). +/// +/// Wire format: `uint16 identity_length + identity` +#[derive(Debug, PartialEq, Eq)] +pub struct ClientPskKeys { + pub identity_range: Range, +} + +impl ClientPskKeys { + pub fn identity<'a>(&self, buf: &'a [u8]) -> &'a [u8] { + &buf[self.identity_range.clone()] + } + + pub fn parse(input: &[u8], base_offset: usize) -> IResult<&[u8], ClientPskKeys> { + let original_input = input; + let (input, identity_len) = nom::number::complete::be_u16(input)?; + let (input, identity_slice) = take(identity_len as usize)(input)?; + + let relative_offset = identity_slice.as_ptr() as usize - original_input.as_ptr() as usize; + let start = base_offset + relative_offset; + let end = start + identity_slice.len(); + + Ok(( + input, + ClientPskKeys { + identity_range: start..end, + }, + )) + } + + pub fn serialize(&self, buf: &[u8], output: &mut Buf) { + let identity = self.identity(buf); + output.extend_from_slice(&(identity.len() as u16).to_be_bytes()); + output.extend_from_slice(identity); + } + + /// Serialize directly from identity bytes (for sending). + pub fn serialize_from_bytes(identity: &[u8], output: &mut Buf) { + output.extend_from_slice(&(identity.len() as u16).to_be_bytes()); + output.extend_from_slice(identity); + } +} + #[cfg(test)] mod test { use super::super::KeyExchangeAlgorithm; diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 75d75d5..78d8d3d 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -27,7 +27,7 @@ pub use certificate::Certificate; pub use certificate_request::CertificateRequest; pub use certificate_verify::CertificateVerify; pub use client_hello::ClientHello; -pub use client_key_exchange::{ClientKeyExchange, ExchangeKeys}; +pub use client_key_exchange::{ClientKeyExchange, ClientPskKeys, ExchangeKeys}; pub use digitally_signed::DigitallySigned; pub use extension::{Extension, ExtensionType}; pub use extensions::signature_algorithms::SignatureAlgorithmsExtension; @@ -46,7 +46,7 @@ pub use crate::types::{ Random, Sequence, SignatureAlgorithm, }; pub use server_hello::ServerHello; -pub use server_key_exchange::{ServerKeyExchange, ServerKeyExchangeParams}; +pub use server_key_exchange::{PskParams, ServerKeyExchange, ServerKeyExchangeParams}; pub use wrapped::{Asn1Cert, DistinguishedName}; use nom::IResult; @@ -66,6 +66,10 @@ pub enum Dtls12CipherSuite { /// ECDHE with ECDSA authentication, ChaCha20-Poly1305, SHA-256 ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, // 0xCCA9 + // PSK cipher suites (no certificate authentication) + /// PSK with AES-128-CCM-8 (8-byte tag), SHA-256 + PSK_AES128_CCM_8, // 0xC0A8 + /// Unknown or unsupported cipher suite by its IANA value Unknown(u16), } @@ -85,6 +89,9 @@ impl Dtls12CipherSuite { 0xC02B => Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256, 0xCCA9 => Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, + // PSK + 0xC0A8 => Dtls12CipherSuite::PSK_AES128_CCM_8, + _ => Dtls12CipherSuite::Unknown(value), } } @@ -97,6 +104,8 @@ impl Dtls12CipherSuite { Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => 0xC02B, Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 0xCCA9, + Dtls12CipherSuite::PSK_AES128_CCM_8 => 0xC0A8, + Dtls12CipherSuite::Unknown(value) => *value, } } @@ -113,7 +122,8 @@ impl Dtls12CipherSuite { // AES-GCM suites Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 12, + | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 + | Dtls12CipherSuite::PSK_AES128_CCM_8 => 12, Dtls12CipherSuite::Unknown(_) => 12, // Default length for unknown cipher suites } @@ -129,6 +139,8 @@ impl Dtls12CipherSuite { KeyExchangeAlgorithm::EECDH } + Dtls12CipherSuite::PSK_AES128_CCM_8 => KeyExchangeAlgorithm::PSK, + Dtls12CipherSuite::Unknown(_) => KeyExchangeAlgorithm::Unknown, } } @@ -143,12 +155,18 @@ impl Dtls12CipherSuite { ) } + /// Whether this cipher suite uses PSK (Pre-Shared Key) key exchange. + pub fn is_psk(&self) -> bool { + matches!(self, Dtls12CipherSuite::PSK_AES128_CCM_8) + } + /// All supported cipher suites in server preference order. - pub const fn all() -> &'static [Dtls12CipherSuite; 3] { + pub const fn all() -> &'static [Dtls12CipherSuite; 4] { &[ Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384, Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256, Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, + Dtls12CipherSuite::PSK_AES128_CCM_8, ] } @@ -179,18 +197,24 @@ impl Dtls12CipherSuite { match self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => HashAlgorithm::SHA384, Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => HashAlgorithm::SHA256, + | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 + | Dtls12CipherSuite::PSK_AES128_CCM_8 => HashAlgorithm::SHA256, Dtls12CipherSuite::Unknown(_) => HashAlgorithm::Unknown(0), } } /// The signature algorithm associated with the suite's key exchange. - pub fn signature_algorithm(&self) -> SignatureAlgorithm { + /// + /// Returns `None` for PSK cipher suites (no signature authentication). + pub fn signature_algorithm(&self) -> Option { match self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => SignatureAlgorithm::ECDSA, - Dtls12CipherSuite::Unknown(_) => SignatureAlgorithm::Unknown(0), + | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => { + Some(SignatureAlgorithm::ECDSA) + } + Dtls12CipherSuite::PSK_AES128_CCM_8 => None, + Dtls12CipherSuite::Unknown(_) => Some(SignatureAlgorithm::Unknown(0)), } } @@ -200,7 +224,7 @@ impl Dtls12CipherSuite { } /// Supported DTLS 1.2 cipher suites in server preference order. - pub const fn supported() -> &'static [Dtls12CipherSuite; 3] { + pub const fn supported() -> &'static [Dtls12CipherSuite; 4] { Self::all() } } @@ -213,6 +237,7 @@ pub type CompressionMethodVec = #[allow(clippy::upper_case_acronyms)] pub enum KeyExchangeAlgorithm { EECDH, + PSK, Unknown, } diff --git a/src/dtls12/message/server_key_exchange.rs b/src/dtls12/message/server_key_exchange.rs index 41651fa..e868a76 100644 --- a/src/dtls12/message/server_key_exchange.rs +++ b/src/dtls12/message/server_key_exchange.rs @@ -14,6 +14,7 @@ pub struct ServerKeyExchange { #[derive(Debug, PartialEq, Eq)] pub enum ServerKeyExchangeParams { Ecdh(EcdhParams), + Psk(PskParams), } impl ServerKeyExchange { @@ -27,6 +28,10 @@ impl ServerKeyExchange { let (input, ecdh_params) = EcdhParams::parse(input, base_offset)?; (input, ServerKeyExchangeParams::Ecdh(ecdh_params)) } + KeyExchangeAlgorithm::PSK => { + let (input, psk_params) = PskParams::parse(input, base_offset)?; + (input, ServerKeyExchangeParams::Psk(psk_params)) + } _ => return Err(Err::Failure(Error::new(input, ErrorKind::Tag))), }; @@ -38,12 +43,14 @@ impl ServerKeyExchange { ServerKeyExchangeParams::Ecdh(ecdh_params) => { ecdh_params.serialize(buf, output, with_signature) } + ServerKeyExchangeParams::Psk(psk_params) => psk_params.serialize(buf, output), } } pub fn signature(&self) -> Option<&DigitallySigned> { match &self.params { ServerKeyExchangeParams::Ecdh(ecdh_params) => ecdh_params.signature.as_ref(), + ServerKeyExchangeParams::Psk(_) => None, } } } @@ -113,6 +120,49 @@ impl EcdhParams { } } +/// PSK identity hint (RFC 4279 §2). +/// +/// Wire format: `uint16 hint_length + hint` +#[derive(Debug, PartialEq, Eq)] +pub struct PskParams { + pub hint_range: Range, +} + +impl PskParams { + pub fn hint<'a>(&self, buf: &'a [u8]) -> &'a [u8] { + &buf[self.hint_range.clone()] + } + + pub fn parse(input: &[u8], base_offset: usize) -> IResult<&[u8], PskParams> { + let original_input = input; + let (input, hint_len) = nom::number::complete::be_u16(input)?; + let (input, hint_slice) = take(hint_len as usize)(input)?; + + let relative_offset = hint_slice.as_ptr() as usize - original_input.as_ptr() as usize; + let start = base_offset + relative_offset; + let end = start + hint_slice.len(); + + Ok(( + input, + PskParams { + hint_range: start..end, + }, + )) + } + + pub fn serialize(&self, buf: &[u8], output: &mut Buf) { + let hint = self.hint(buf); + output.extend_from_slice(&(hint.len() as u16).to_be_bytes()); + output.extend_from_slice(hint); + } + + /// Serialize directly from hint bytes (for sending). + pub fn serialize_from_bytes(hint: &[u8], output: &mut Buf) { + output.extend_from_slice(&(hint.len() as u16).to_be_bytes()); + output.extend_from_slice(hint); + } +} + #[cfg(test)] mod test { use super::super::{HashAlgorithm, SignatureAlgorithm, SignatureAndHashAlgorithm}; diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index b157908..300ca7e 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -23,7 +23,9 @@ use crate::buffer::{Buf, ToBuf}; use crate::crypto::SrtpProfile; use crate::dtls12::Client; use crate::dtls12::client::LocalEvent; +use crate::dtls12::context::AuthMode; use crate::dtls12::engine::Engine; +use crate::dtls12::message::PskParams; use crate::dtls12::message::{Body, CertificateRequest, CertificateTypeVec, Dtls12CipherSuite}; use crate::dtls12::message::{ClientCertificateType, CompressionMethod, ContentType}; use crate::dtls12::message::{Cookie, CurveType, DistinguishedName, ExchangeKeys, ExtensionType}; @@ -76,6 +78,10 @@ pub struct Server { /// Captured session hash for Extended Master Secret (RFC 7627) captured_session_hash: Option, + /// Whether the PSK identity resolved to a real key. + /// Defaults to `true` so non-PSK paths are unaffected. + psk_valid: bool, + /// The last now we seen last_now: Instant, @@ -108,7 +114,22 @@ enum State { impl Server { /// Create a new DTLS server pub fn new(config: Arc, certificate: crate::DtlsCertificate, now: Instant) -> Server { - let engine = Engine::new(config, certificate); + let private_key = config + .crypto_provider() + .key_provider + .load_private_key(&certificate.private_key) + .expect("Failed to parse server private key"); + let auth = AuthMode::Certificate { + certificate: certificate.certificate, + private_key, + }; + let engine = Engine::new(config, auth); + Self::new_with_engine(engine, now) + } + + /// Create a new PSK-only DTLS server (no certificate). + pub fn new_psk(config: Arc, now: Instant) -> Server { + let engine = Engine::new(config, AuthMode::Psk); Self::new_with_engine(engine, now) } @@ -131,6 +152,7 @@ impl Server { client_certificates: Vec::with_capacity(3), defragment_buffer: Buf::new(), captured_session_hash: None, + psk_valid: true, last_now: now, local_events: VecDeque::new(), queued_data: Vec::new(), @@ -439,7 +461,17 @@ impl State { ) })?; - Ok(Self::SendCertificate) + let cs = server + .engine + .cipher_suite() + .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; + + // PSK suites skip Certificate + if cs.is_psk() { + Ok(Self::SendServerKeyExchange) + } else { + Ok(Self::SendCertificate) + } } fn send_certificate(self, server: &mut Server) -> Result { @@ -455,6 +487,15 @@ impl State { fn send_server_key_exchange(self, server: &mut Server) -> Result { trace!("Sending ServerKeyExchange"); + let cs = server + .engine + .cipher_suite() + .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; + + if cs.is_psk() { + return self.send_server_key_exchange_psk(server); + } + let client_random = server .client_random .ok_or_else(|| Error::UnexpectedMessage("No client random".to_string()))?; @@ -489,9 +530,14 @@ impl State { // Select signature/hash for SKE by intersecting client's list // with our key type (prefer SHA256, then SHA384) + // unwrap: ServerKeyExchange signature only needed for certificate-based suites let selected_signature = select_ske_signature_algorithm( server.client_signature_algorithms.as_ref(), - server.engine.crypto_context().signature_algorithm(), + server + .engine + .crypto_context() + .signature_algorithm() + .unwrap(), ); debug!( @@ -519,6 +565,26 @@ impl State { } } + /// PSK ServerKeyExchange: send identity hint only (no ECDHE, no signature). + fn send_server_key_exchange_psk(self, server: &mut Server) -> Result { + let hint = server + .engine + .config() + .psk_identity_hint() + .unwrap_or(&[]) + .to_vec(); + + server + .engine + .create_handshake(MessageType::ServerKeyExchange, move |body, _engine| { + PskParams::serialize_from_bytes(&hint, body); + Ok(()) + })?; + + // PSK never sends CertificateRequest + Ok(Self::SendServerHelloDone) + } + fn send_certificate_request(self, server: &mut Server) -> Result { debug!("Sending CertificateRequest"); // Select CertificateRequest.signature_algorithms as intersection of client's list and our supported @@ -545,6 +611,16 @@ impl State { .engine .create_handshake(MessageType::ServerHelloDone, |_, _| Ok(()))?; + let cs = server + .engine + .cipher_suite() + .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; + + // PSK: no client certificates + if cs.is_psk() { + return Ok(Self::AwaitClientKeyExchange); + } + if server.engine.config().require_client_certificate() { Ok(Self::AwaitCertificate) } else { @@ -619,31 +695,85 @@ impl State { .cipher_suite() .ok_or_else(|| Error::UnexpectedMessage("No cipher suite selected".to_string()))?; - // Extract client's public key range before dropping handshake - let public_key_range = match &ckx.exchange_keys { - ExchangeKeys::Ecdh(keys) => keys.public_key_range.clone(), - }; + if suite.is_psk() { + // Extract PSK identity range before dropping handshake + let identity_range = match &ckx.exchange_keys { + ExchangeKeys::Psk(keys) => keys.identity_range.clone(), + _ => { + return Err(Error::UnexpectedMessage( + "ECDHE ClientKeyExchange in PSK path".to_string(), + )); + } + }; - drop(maybe); + drop(maybe); - // Get the actual public key data from defragment_buffer - let client_pub = &server.defragment_buffer[public_key_range]; + let identity = &server.defragment_buffer[identity_range]; + trace!("PSK identity ({} bytes)", identity.len()); - // Compute shared secret - let mut buf = server.engine.pop_buffer(); - server - .engine - .crypto_context_mut() - .compute_shared_secret(client_pub, &mut buf) - .map_err(|e| Error::CryptoError(format!("Failed to compute shared secret: {}", e)))?; + // Resolve PSK via the configured resolver + let (psk, psk_valid) = { + let resolver = server + .engine + .config() + .psk_resolver() + .ok_or_else(|| Error::PskError("No PSK resolver configured".to_string()))?; + + match resolver.resolve(identity) { + Some(key) => (key, true), + None => { + // Use a dummy PSK so the handshake proceeds identically + // to a valid-identity flow. It will fail at Finished + // verification, making the two cases indistinguishable. + let dummy = vec![0u8; 32]; // length should match your typical PSK size + (dummy, false) + } + } + }; + + // Saving to server struct + server.psk_valid = psk_valid; + + let crypto = server.engine.crypto_context_mut(); + crypto.set_psk(psk); + crypto + .compute_psk_pre_master_secret() + .map_err(|e| Error::CryptoError(format!("Failed to compute PSK PMS: {}", e)))?; + } else { + // Extract client's public key range before dropping handshake + let public_key_range = match &ckx.exchange_keys { + ExchangeKeys::Ecdh(keys) => keys.public_key_range.clone(), + ExchangeKeys::Psk(_) => { + return Err(Error::UnexpectedMessage( + "PSK ClientKeyExchange in ECDHE path".to_string(), + )); + } + }; + + drop(maybe); + + // Get the actual public key data from defragment_buffer + let client_pub = &server.defragment_buffer[public_key_range]; + + // Compute shared secret + let mut buf = server.engine.pop_buffer(); + server + .engine + .crypto_context_mut() + .compute_shared_secret(client_pub, &mut buf) + .map_err(|e| { + Error::CryptoError(format!("Failed to compute shared secret: {}", e)) + })?; + server.engine.push_buffer(buf); + } // Capture session hash for EMS now (up to ClientKeyExchange) let suite_hash = suite.hash_algorithm(); + let mut buf = server.engine.pop_buffer(); server.engine.transcript_hash(suite_hash, &mut buf); server.captured_session_hash = Some(buf); // Derive master secret and keys (needed to decrypt client's Finished) - let suite_hash = suite.hash_algorithm(); let client_random_buf = { let mut b = Buf::new(); server.client_random.unwrap().serialize(&mut b); @@ -820,6 +950,14 @@ impl State { )); } + // Defense-in-depth: dummy PSK should always fail above, + // but reject explicitly in case it accidentally passes. + if !server.psk_valid { + return Err(Error::SecurityError( + "Client Finished verification failed".to_string(), + )); + } + trace!("Client Finished verified successfully"); Ok(Self::SendChangeCipherSpec) diff --git a/src/error.rs b/src/error.rs index dce6ec5..a3245ab 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { CertificateError(String), /// Security policy violation SecurityError(String), + /// PSK (Pre-Shared Key) error + PskError(String), /// Incoming queue exceeded capacity ReceiveQueueFull, /// Outgoing queue exceeded capacity @@ -71,6 +73,7 @@ impl std::fmt::Display for Error { Error::CryptoError(msg) => write!(f, "crypto error: {}", msg), Error::CertificateError(msg) => write!(f, "certificate error: {}", msg), Error::SecurityError(msg) => write!(f, "security error: {}", msg), + Error::PskError(msg) => write!(f, "psk error: {}", msg), Error::ReceiveQueueFull => write!(f, "receive queue full"), Error::TransmitQueueFull => write!(f, "transmit queue full"), Error::IncompleteServerHello => write!(f, "incomplete ServerHello"), diff --git a/src/lib.rs b/src/lib.rs index ecbe4df..d5183d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,8 +20,9 @@ //! //! ## Version selection //! -//! Three constructors control which DTLS version is used: -//! - [`Dtls::new_12`][new_12] — explicit DTLS 1.2 +//! Four constructors control which DTLS version is used: +//! - [`Dtls::new_12`][new_12] — explicit DTLS 1.2 (certificate‑based) +//! - [`Dtls::new_12_psk`][new_12_psk] — explicit DTLS 1.2 (PSK, no certificates) //! - [`Dtls::new_13`][new_13] — explicit DTLS 1.3 //! - [`Dtls::new_auto`][new_auto] — auto‑sense: the first //! incoming ClientHello determines the version (based on the @@ -32,6 +33,8 @@ //! - `ECDHE_ECDSA_AES256_GCM_SHA384` //! - `ECDHE_ECDSA_AES128_GCM_SHA256` //! - `ECDHE_ECDSA_CHACHA20_POLY1305_SHA256` +//! - **PSK cipher suites (TLS 1.2 over DTLS)** +//! - `PSK_AES128_CCM_8` //! - **Cipher suites (TLS 1.3 over DTLS)** //! - `TLS_AES_128_GCM_SHA256` //! - `TLS_AES_256_GCM_SHA384` @@ -42,7 +45,6 @@ //! - **DTLS‑SRTP**: Exports keying material for `SRTP_AEAD_AES_256_GCM`, //! `SRTP_AEAD_AES_128_GCM`, and `SRTP_AES128_CM_SHA1_80` ([RFC 5764], [RFC 7714]). //! - **Extended Master Secret** ([RFC 7627]) is negotiated and enforced (DTLS 1.2). -//! - Not supported: PSK cipher suites. //! //! ## Certificate model //! During the handshake the engine emits @@ -132,6 +134,37 @@ //! # } //! ``` //! +//! ## Example (PSK client) +//! +//! ```rust,no_run +//! use std::sync::Arc; +//! use std::time::Instant; +//! +//! use dimpl::{Config, Dtls, PskResolver}; +//! +//! struct MyPsk; +//! +//! impl PskResolver for MyPsk { +//! fn resolve(&self, identity: &[u8]) -> Option> { +//! if identity == b"device-01" { +//! Some(b"shared-secret-key".to_vec()) +//! } else { +//! None +//! } +//! } +//! } +//! +//! let config = Arc::new( +//! Config::builder() +//! .with_psk_client(b"device-01".to_vec(), Arc::new(MyPsk)) +//! .build() +//! .unwrap(), +//! ); +//! +//! let mut dtls = Dtls::new_12_psk(config, Instant::now()); +//! dtls.set_active(true); // client role +//! ``` +//! //! ### MSRV //! Rust 1.85.0 //! @@ -140,6 +173,7 @@ //! - Renegotiation is not implemented (WebRTC does full restart). //! //! [new_12]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_12 +//! [new_12_psk]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_12_psk //! [new_13]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_13 //! [new_auto]: https://docs.rs/dimpl/latest/dimpl/struct.Dtls.html#method.new_auto //! [peer_cert]: https://docs.rs/dimpl/latest/dimpl/enum.Output.html#variant.PeerCert @@ -192,7 +226,7 @@ mod error; pub use error::Error; mod config; -pub use config::Config; +pub use config::{Config, ConfigBuilder, Psk, PskResolver}; #[cfg(feature = "rcgen")] pub mod certificate; @@ -260,6 +294,17 @@ impl Dtls { Dtls { inner: Some(inner) } } + /// Create a new DTLS 1.2 PSK-only instance (no certificate). + /// + /// Call [`set_active(true)`](Self::set_active) to switch to client + /// before the handshake begins. The `config` must have a + /// [`PskResolver`] configured, and for clients a PSK identity + /// via [`ConfigBuilder::with_psk_client`](ConfigBuilder). + pub fn new_12_psk(config: Arc, now: Instant) -> Self { + let inner = Inner::Server12(Server12::new_psk(config, now)); + Dtls { inner: Some(inner) } + } + /// Create a new DTLS 1.3 instance in the server role. /// /// Call [`set_active(true)`](Self::set_active) to switch to client diff --git a/tests/dtls12/crypto.rs b/tests/dtls12/crypto.rs index b95f77f..68bfdca 100644 --- a/tests/dtls12/crypto.rs +++ b/tests/dtls12/crypto.rs @@ -67,7 +67,8 @@ fn dtls12_all_cipher_suites() { let _ = env_logger::try_init(); // Loop over all supported cipher suites and ensure we can connect - for &suite in Dtls12CipherSuite::all().iter() { + // Skip PSK suites — they require PSK config, not certificate-based interop + for &suite in Dtls12CipherSuite::all().iter().filter(|s| !s.is_psk()) { eprintln!("Testing suite (dimpl client ↔️ ossl server): {:?}", suite); run_dimpl_client_vs_ossl_server_for_suite(suite); @@ -101,8 +102,8 @@ fn config_for_suite(suite: Dtls12CipherSuite) -> Arc { fn run_dimpl_client_vs_ossl_server_for_suite(suite: Dtls12CipherSuite) { // Generate certificates for both client and server matching the suite's signature algorithm let pkey_type = match suite.signature_algorithm() { - SignatureAlgorithm::ECDSA => DtlsPKeyType::EcDsaP256, - SignatureAlgorithm::RSA => DtlsPKeyType::Rsa2048, + Some(SignatureAlgorithm::ECDSA) => DtlsPKeyType::EcDsaP256, + Some(SignatureAlgorithm::RSA) => DtlsPKeyType::Rsa2048, _ => panic!("Unsupported signature algorithm in suite: {:?}", suite), }; @@ -211,8 +212,8 @@ fn run_dimpl_client_vs_ossl_server_for_suite(suite: Dtls12CipherSuite) { fn run_ossl_client_vs_dimpl_server_for_suite(suite: Dtls12CipherSuite) { // Generate certificates for both ends let pkey_type = match suite.signature_algorithm() { - SignatureAlgorithm::ECDSA => DtlsPKeyType::EcDsaP256, - SignatureAlgorithm::RSA => DtlsPKeyType::Rsa2048, + Some(SignatureAlgorithm::ECDSA) => DtlsPKeyType::EcDsaP256, + Some(SignatureAlgorithm::RSA) => DtlsPKeyType::Rsa2048, _ => panic!("Unsupported signature algorithm in suite: {:?}", suite), }; diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 1e17cb2..1bb677e 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -3,46 +3,10 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use dimpl::{Dtls, Output}; +use dimpl::Dtls; use crate::common::*; -/// Collected outputs from polling a DTLS 1.2 endpoint to `Timeout`. -#[derive(Default, Debug)] -struct DrainedOutputs { - packets: Vec>, - connected: bool, - app_data: Vec>, - timeout: Option, -} - -/// Poll until `Timeout`, collecting everything. -fn drain_outputs(endpoint: &mut Dtls) -> DrainedOutputs { - let mut result = DrainedOutputs::default(); - let mut buf = vec![0u8; 2048]; - loop { - match endpoint.poll_output(&mut buf) { - Output::Packet(p) => result.packets.push(p.to_vec()), - Output::Connected => result.connected = true, - Output::ApplicationData(data) => result.app_data.push(data.to_vec()), - Output::Timeout(t) => { - result.timeout = Some(t); - break; - } - _ => {} - } - } - result -} - -/// Deliver a slice of packets to a destination endpoint. -fn deliver_packets(packets: &[Vec], dest: &mut Dtls) { - for p in packets { - // Ignore errors - they may be expected for duplicates/replays - let _ = dest.handle_packet(p); - } -} - /// Complete a full DTLS 1.2 handshake between client and server. /// /// Returns the final `Instant` (time advanced during the handshake). diff --git a/tests/dtls12/main.rs b/tests/dtls12/main.rs index 329b185..c77bc49 100644 --- a/tests/dtls12/main.rs +++ b/tests/dtls12/main.rs @@ -8,5 +8,6 @@ mod edge; mod fragmentation; mod handshake; mod ossl; +mod psk; mod reorder; mod retransmit; diff --git a/tests/dtls12/ossl.rs b/tests/dtls12/ossl.rs index e1803a8..65a7f28 100644 --- a/tests/dtls12/ossl.rs +++ b/tests/dtls12/ossl.rs @@ -1,10 +1,12 @@ //! DTLS 1.2 interop tests: dimpl <-> OpenSSL (client + server). use std::collections::VecDeque; +use std::io::{self, Read, Write}; use std::sync::Arc; use std::time::Instant; -use dimpl::{Config, Dtls, Output}; +use dimpl::crypto::Dtls12CipherSuite; +use dimpl::{Config, Dtls, Output, PskResolver}; use crate::ossl_helper::{DtlsCertOptions, DtlsEvent, OsslDtlsCert}; @@ -892,3 +894,400 @@ fn dtls12_ossl_server_bidirectional_data() { "Client should receive both server messages" ); } + +// ============================================================================ +// PSK interop tests +// ============================================================================ + +const PSK_IDENTITY: &[u8] = b"test-device"; +const PSK_KEY: &[u8] = b"0123456789abcdef"; // 16 bytes + +struct FixedPsk; + +impl PskResolver for FixedPsk { + fn resolve(&self, identity: &[u8]) -> Option> { + if identity == PSK_IDENTITY { + Some(PSK_KEY.to_vec()) + } else { + None + } + } +} + +fn psk_provider() -> dimpl::crypto::CryptoProvider { + let mut provider = Config::default().crypto_provider().clone(); + let psk_suite = provider + .cipher_suites + .iter() + .copied() + .find(|cs| cs.suite() == Dtls12CipherSuite::PSK_AES128_CCM_8) + .expect("PSK_AES128_CCM_8 not in provider"); + + let suites = Box::leak(Box::new([psk_suite])); + provider.cipher_suites = suites; + provider +} + +fn psk_dimpl_client_config() -> Arc { + Arc::new( + Config::builder() + .with_crypto_provider(psk_provider()) + .with_psk_client(PSK_IDENTITY.to_vec(), Arc::new(FixedPsk)) + .build() + .expect("build PSK client config"), + ) +} + +fn psk_dimpl_server_config() -> Arc { + Arc::new( + Config::builder() + .with_crypto_provider(psk_provider()) + .with_psk_server(Some(b"hint".to_vec()), Arc::new(FixedPsk)) + .build() + .expect("build PSK server config"), + ) +} + +/// Create an OpenSSL PSK DTLS context configured as server. +fn ossl_psk_server() -> openssl::ssl::Ssl { + use openssl::ssl::{SslContextBuilder, SslMethod, SslOptions, SslVerifyMode}; + + let mut ctx = SslContextBuilder::new(SslMethod::dtls()).unwrap(); + ctx.set_cipher_list("PSK-AES128-CCM8").unwrap(); + + // No peer cert verification for PSK + ctx.set_verify(SslVerifyMode::NONE); + + let mut options = SslOptions::empty(); + options.insert(SslOptions::NO_DTLSV1); + ctx.set_options(options); + + ctx.set_psk_server_callback(|_ssl, identity, psk_out| { + if let Some(id) = identity { + if id == PSK_IDENTITY { + psk_out[..PSK_KEY.len()].copy_from_slice(PSK_KEY); + return Ok(PSK_KEY.len()); + } + } + Ok(0) + }); + + let ctx = ctx.build(); + let mut ssl = openssl::ssl::Ssl::new(&ctx).unwrap(); + ssl.set_mtu(1150).expect("set MTU"); + ssl +} + +/// Create an OpenSSL PSK DTLS context configured as client. +fn ossl_psk_client() -> openssl::ssl::Ssl { + use openssl::ssl::{SslContextBuilder, SslMethod, SslOptions, SslVerifyMode}; + + let mut ctx = SslContextBuilder::new(SslMethod::dtls()).unwrap(); + ctx.set_cipher_list("PSK-AES128-CCM8").unwrap(); + + ctx.set_verify(SslVerifyMode::NONE); + + let mut options = SslOptions::empty(); + options.insert(SslOptions::NO_DTLSV1); + ctx.set_options(options); + + ctx.set_psk_client_callback(|_ssl, _hint, identity_out, psk_out| { + identity_out[..PSK_IDENTITY.len()].copy_from_slice(PSK_IDENTITY); + identity_out[PSK_IDENTITY.len()] = 0; // null terminate + psk_out[..PSK_KEY.len()].copy_from_slice(PSK_KEY); + Ok(PSK_KEY.len()) + }); + + let ctx = ctx.build(); + let mut ssl = openssl::ssl::Ssl::new(&ctx).unwrap(); + ssl.set_mtu(1150).expect("set MTU"); + ssl +} + +type IoBuffer = crate::ossl_helper::io_buf::IoBuffer; + +/// A minimal OpenSSL PSK endpoint. No certs, no SRTP — just PSK handshake + data. +struct OsslPskEndpoint { + active: bool, + state: Option, +} + +enum OsslPskState { + Init(openssl::ssl::Ssl, IoBuffer), + Handshaking(openssl::ssl::MidHandshakeSslStream), + Established(openssl::ssl::SslStream), +} + +impl OsslPskEndpoint { + fn new(ssl: openssl::ssl::Ssl, active: bool) -> Self { + OsslPskEndpoint { + active, + state: Some(OsslPskState::Init(ssl, IoBuffer::default())), + } + } + + fn io_buf(&mut self) -> &mut IoBuffer { + match self.state.as_mut().expect("state") { + OsslPskState::Init(_, buf) => buf, + OsslPskState::Handshaking(mid) => mid.get_mut(), + OsslPskState::Established(stream) => stream.get_mut(), + } + } + + /// Feed incoming data and drive the handshake. Returns true on first connect. + fn handle_receive(&mut self, data: &[u8]) -> bool { + self.io_buf().set_incoming(data); + self.drive_handshake() + } + + fn drive_handshake(&mut self) -> bool { + let taken = self.state.take().expect("state"); + + let result = match taken { + OsslPskState::Init(ssl, buf) => { + if self.active { + ssl.connect(buf) + } else { + ssl.accept(buf) + } + } + OsslPskState::Handshaking(mid) => mid.handshake(), + OsslPskState::Established(stream) => { + self.state = Some(OsslPskState::Established(stream)); + return false; + } + }; + + match result { + Ok(stream) => { + self.state = Some(OsslPskState::Established(stream)); + true + } + Err(openssl::ssl::HandshakeError::WouldBlock(mid)) => { + self.state = Some(OsslPskState::Handshaking(mid)); + false + } + Err(e) => panic!("OpenSSL PSK handshake error: {:?}", e), + } + } + + fn poll_datagram(&mut self) -> Option { + self.io_buf().pop_outgoing() + } + + fn send_data(&mut self, data: &[u8]) { + if let Some(OsslPskState::Established(stream)) = &mut self.state { + stream.write_all(data).expect("send data"); + } else { + panic!("not connected"); + } + } + + fn read_data(&mut self) -> Option> { + if let Some(OsslPskState::Established(stream)) = &mut self.state { + let mut buf = vec![0u8; 2000]; + match stream.read(&mut buf) { + Ok(n) => { + buf.truncate(n); + Some(buf) + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => None, + Err(e) => panic!("read error: {:?}", e), + } + } else { + None + } + } +} + +#[test] +#[ignore = "OpenSSL does not support PSK-AES128-CCM8 over DTLS (only TLS)"] +fn dtls12_ossl_psk_dimpl_client_ossl_server() { + env_logger::try_init().ok(); + + let config = psk_dimpl_client_config(); + let now = Instant::now(); + + let mut client = Dtls::new_12_psk(config, now); + client.set_active(true); + + let ssl = ossl_psk_server(); + let mut server = OsslPskEndpoint::new(ssl, false); + + let mut client_connected = false; + let mut server_connected = false; + let mut out_buf = vec![0u8; 2048]; + + for _ in 0..30 { + client.handle_timeout(Instant::now()).unwrap(); + + // Poll dimpl client → OpenSSL server + loop { + match client.poll_output(&mut out_buf) { + Output::Packet(data) => { + if server.handle_receive(data) { + server_connected = true; + } + } + Output::Connected => { + client_connected = true; + } + Output::Timeout(_) => break, + _ => {} + } + } + + // Poll OpenSSL server → dimpl client + while let Some(datagram) = server.poll_datagram() { + client.handle_packet(&datagram).expect("handle server pkt"); + } + + // Poll dimpl again after receiving server packets + loop { + match client.poll_output(&mut out_buf) { + Output::Packet(data) => { + if server.handle_receive(data) { + server_connected = true; + } + } + Output::Connected => { + client_connected = true; + } + Output::Timeout(_) => break, + _ => {} + } + } + + // Drive OpenSSL again in case dimpl sent more + while let Some(datagram) = server.poll_datagram() { + client.handle_packet(&datagram).expect("handle server pkt"); + } + + if client_connected && server_connected { + break; + } + } + + assert!(client_connected, "dimpl PSK client should connect"); + assert!(server_connected, "OpenSSL PSK server should connect"); + + // App data: client → server + client + .send_application_data(b"hello from dimpl") + .expect("send"); + loop { + match client.poll_output(&mut out_buf) { + Output::Packet(data) => { + server.handle_receive(data); + } + Output::Timeout(_) => break, + _ => {} + } + } + + let received = server.read_data().expect("server should receive data"); + assert_eq!(received, b"hello from dimpl"); + + // App data: server → client + server.send_data(b"hello from openssl"); + while let Some(datagram) = server.poll_datagram() { + client.handle_packet(&datagram).expect("handle server pkt"); + } + + let mut client_data = Vec::new(); + loop { + match client.poll_output(&mut out_buf) { + Output::ApplicationData(data) => client_data.extend_from_slice(data), + Output::Timeout(_) => break, + _ => {} + } + } + assert_eq!(client_data, b"hello from openssl"); +} + +#[test] +#[ignore = "OpenSSL does not support PSK-AES128-CCM8 over DTLS (only TLS)"] +fn dtls12_ossl_psk_ossl_client_dimpl_server() { + env_logger::try_init().ok(); + + let config = psk_dimpl_server_config(); + let now = Instant::now(); + + let mut server = Dtls::new_12_psk(config, now); + server.set_active(false); + + let ssl = ossl_psk_client(); + let mut client = OsslPskEndpoint::new(ssl, true); + + // Kick off OpenSSL client handshake + client.handle_receive(&[]); + + let mut server_connected = false; + let mut client_connected = false; + let mut out_buf = vec![0u8; 2048]; + + for _ in 0..30 { + // Poll OpenSSL client → dimpl server + while let Some(datagram) = client.poll_datagram() { + server.handle_packet(&datagram).expect("handle client pkt"); + } + + server.handle_timeout(Instant::now()).unwrap(); + + // Poll dimpl server → OpenSSL client + loop { + match server.poll_output(&mut out_buf) { + Output::Packet(data) => { + if client.handle_receive(data) { + client_connected = true; + } + } + Output::Connected => { + server_connected = true; + } + Output::Timeout(_) => break, + _ => {} + } + } + + if client_connected && server_connected { + break; + } + } + + assert!(client_connected, "OpenSSL PSK client should connect"); + assert!(server_connected, "dimpl PSK server should connect"); + + // App data: OpenSSL client → dimpl server + client.send_data(b"hello from openssl client"); + while let Some(datagram) = client.poll_datagram() { + server.handle_packet(&datagram).expect("handle client pkt"); + } + + let mut server_data = Vec::new(); + loop { + match server.poll_output(&mut out_buf) { + Output::ApplicationData(data) => server_data.extend_from_slice(data), + Output::Timeout(_) => break, + _ => {} + } + } + assert_eq!(server_data, b"hello from openssl client"); + + // App data: dimpl server → OpenSSL client + server + .send_application_data(b"hello from dimpl server") + .expect("send"); + loop { + match server.poll_output(&mut out_buf) { + Output::Packet(data) => { + client.handle_receive(data); + } + Output::Timeout(_) => break, + _ => {} + } + } + + let received = client.read_data().expect("client should receive data"); + assert_eq!(received, b"hello from dimpl server"); +} diff --git a/tests/dtls12/psk.rs b/tests/dtls12/psk.rs new file mode 100644 index 0000000..5be0ab5 --- /dev/null +++ b/tests/dtls12/psk.rs @@ -0,0 +1,331 @@ +//! DTLS 1.2 PSK handshake tests. + +use std::sync::Arc; +use std::time::Instant; + +use dimpl::crypto::Dtls12CipherSuite; +use dimpl::{Config, Dtls, Error, PskResolver}; + +use crate::common::{deliver_packets, drain_outputs}; + +/// Simple PSK resolver that returns a fixed key for a known identity. +struct FixedPsk { + identity: Vec, + key: Vec, +} + +impl PskResolver for FixedPsk { + fn resolve(&self, identity: &[u8]) -> Option> { + if identity == self.identity { + Some(self.key.clone()) + } else { + None + } + } +} + +fn psk_provider(suite: Dtls12CipherSuite) -> dimpl::crypto::CryptoProvider { + let mut provider = Config::default().crypto_provider().clone(); + let psk_suite = provider + .cipher_suites + .iter() + .copied() + .find(|cs| cs.suite() == suite) + .unwrap_or_else(|| panic!("{:?} not in provider", suite)); + + let suites = Box::leak(Box::new([psk_suite])); + provider.cipher_suites = suites; + provider +} + +/// Returns (client_config, server_config) for PSK tests. +fn psk_configs_for_suite(suite: Dtls12CipherSuite) -> (Arc, Arc) { + let identity = b"test-device".to_vec(); + let key = b"0123456789abcdef".to_vec(); // 16 bytes + + let resolver = Arc::new(FixedPsk { + identity: identity.clone(), + key, + }); + + let provider = psk_provider(suite); + + let client = Arc::new( + Config::builder() + .with_crypto_provider(provider.clone()) + .with_psk_client(identity, resolver.clone()) + .build() + .expect("build PSK client config"), + ); + + let server = Arc::new( + Config::builder() + .with_crypto_provider(provider) + .with_psk_server(Some(b"hint".to_vec()), resolver) + .build() + .expect("build PSK server config"), + ); + + (client, server) +} + +fn psk_configs() -> (Arc, Arc) { + psk_configs_for_suite(Dtls12CipherSuite::PSK_AES128_CCM_8) +} + +#[test] +fn dtls12_psk_self_handshake() { + let _ = env_logger::try_init(); + + let (client_config, server_config) = psk_configs(); + let now = Instant::now(); + + let mut client = Dtls::new_12_psk(client_config, now); + client.set_active(true); + + let mut server = Dtls::new_12_psk(server_config, now); + server.set_active(false); + + let mut client_connected = false; + let mut server_connected = false; + + for _ in 0..60 { + client.handle_timeout(Instant::now()).unwrap(); + server.handle_timeout(Instant::now()).unwrap(); + + // Drain client → server + let client_out = drain_outputs(&mut client); + if client_out.connected { + client_connected = true; + } + deliver_packets(&client_out.packets, &mut server); + + // Drain server → client + let server_out = drain_outputs(&mut server); + if server_out.connected { + server_connected = true; + } + deliver_packets(&server_out.packets, &mut client); + + if client_connected && server_connected { + break; + } + } + + assert!(client_connected, "PSK client should connect"); + assert!(server_connected, "PSK server should connect"); +} + +#[test] +fn dtls12_psk_application_data_roundtrip() { + let _ = env_logger::try_init(); + + let (client_config, server_config) = psk_configs(); + let now = Instant::now(); + + let mut client = Dtls::new_12_psk(client_config, now); + client.set_active(true); + + let mut server = Dtls::new_12_psk(server_config, now); + server.set_active(false); + + // Complete handshake + for _ in 0..60 { + client.handle_timeout(Instant::now()).unwrap(); + server.handle_timeout(Instant::now()).unwrap(); + + let co = drain_outputs(&mut client); + deliver_packets(&co.packets, &mut server); + + let so = drain_outputs(&mut server); + deliver_packets(&so.packets, &mut client); + + if co.connected || so.connected { + // One more round to let both sides finish + client.handle_timeout(Instant::now()).unwrap(); + server.handle_timeout(Instant::now()).unwrap(); + + let co2 = drain_outputs(&mut client); + deliver_packets(&co2.packets, &mut server); + + let so2 = drain_outputs(&mut server); + deliver_packets(&so2.packets, &mut client); + break; + } + } + + // Send data client → server + let payload = b"Hello from PSK client!"; + client + .send_application_data(payload) + .expect("send app data"); + + let co = drain_outputs(&mut client); + deliver_packets(&co.packets, &mut server); + + let so = drain_outputs(&mut server); + assert!( + so.app_data.iter().any(|d| d == payload), + "Server should receive client's application data" + ); + + // Send data server → client + let reply = b"Hello from PSK server!"; + server.send_application_data(reply).expect("send app data"); + + let so = drain_outputs(&mut server); + deliver_packets(&so.packets, &mut client); + + let co = drain_outputs(&mut client); + assert!( + co.app_data.iter().any(|d| d == reply), + "Client should receive server's application data" + ); +} + +#[test] +fn psk_invalid_identity_fails_at_finished() { + let _ = env_logger::try_init(); + + struct FailingResolver; + impl PskResolver for FailingResolver { + fn resolve(&self, _identity: &[u8]) -> Option> { + None + } + } + + struct PassingResolver; + impl PskResolver for PassingResolver { + fn resolve(&self, _identity: &[u8]) -> Option> { + Some(vec![0u8; 32]) + } + } + + let server_config = dimpl::Config::builder() + .with_psk_server(None, Arc::new(FailingResolver)) + .build() + .expect("server config should build"); + let mut server = Dtls::new_12_psk(Arc::new(server_config), Instant::now()); + + let client_config = dimpl::Config::builder() + .with_psk_client(b"test_identity".to_vec(), Arc::new(PassingResolver)) + .build() + .expect("client config should build"); + let mut client = Dtls::new_12_psk(Arc::new(client_config), Instant::now()); + client.set_active(true); + + // Drive the handshake; expect a SecurityError from mismatched PSK keys. + let mut error_found = false; + for _ in 0..60 { + if let Err(e) = client.handle_timeout(Instant::now()) { + assert!( + matches!(e, Error::SecurityError(_)), + "unexpected error: {e:?}" + ); + error_found = true; + break; + } + let co = drain_outputs(&mut client); + for p in &co.packets { + if let Err(e) = server.handle_packet(p) { + assert!( + matches!(e, Error::SecurityError(_)), + "unexpected error: {e:?}" + ); + error_found = true; + break; + } + } + if error_found { + break; + } + assert!( + !co.connected, + "client should not connect with mismatched PSK" + ); + + if let Err(e) = server.handle_timeout(Instant::now()) { + assert!( + matches!(e, Error::SecurityError(_)), + "unexpected error: {e:?}" + ); + error_found = true; + break; + } + let so = drain_outputs(&mut server); + for p in &so.packets { + if let Err(e) = client.handle_packet(p) { + assert!( + matches!(e, Error::SecurityError(_)), + "unexpected error: {e:?}" + ); + error_found = true; + break; + } + } + if error_found { + break; + } + assert!( + !so.connected, + "server should not connect with mismatched PSK" + ); + } + + assert!( + error_found, + "Expected SecurityError from PSK verification failure" + ); +} + +#[test] +fn psk_valid_identity_succeeds() { + let _ = env_logger::try_init(); + + struct AlwaysPassResolver; + impl PskResolver for AlwaysPassResolver { + fn resolve(&self, _identity: &[u8]) -> Option> { + Some(vec![0u8; 32]) + } + } + + let server_config = dimpl::Config::builder() + .with_psk_server(None, Arc::new(AlwaysPassResolver)) + .build() + .expect("server config should build"); + let mut server = Dtls::new_12_psk(Arc::new(server_config), Instant::now()); + + let client_config = dimpl::Config::builder() + .with_psk_client(b"test_identity".to_vec(), Arc::new(AlwaysPassResolver)) + .build() + .expect("client config should build"); + let mut client = Dtls::new_12_psk(Arc::new(client_config), Instant::now()); + client.set_active(true); + + let mut client_connected = false; + let mut server_connected = false; + + for _ in 0..60 { + client.handle_timeout(Instant::now()).unwrap(); + server.handle_timeout(Instant::now()).unwrap(); + + let co = drain_outputs(&mut client); + if co.connected { + client_connected = true; + } + deliver_packets(&co.packets, &mut server); + + let so = drain_outputs(&mut server); + if so.connected { + server_connected = true; + } + deliver_packets(&so.packets, &mut client); + + if client_connected && server_connected { + break; + } + } + + assert!(client_connected, "PSK client should connect"); + assert!(server_connected, "PSK server should connect"); +} diff --git a/tests/ossl/io_buf.rs b/tests/ossl/io_buf.rs index 6288941..f84daea 100644 --- a/tests/ossl/io_buf.rs +++ b/tests/ossl/io_buf.rs @@ -14,7 +14,7 @@ impl Deref for DatagramSend { } } -#[derive(Default)] +#[derive(Default, Debug)] pub struct IoBuffer { pub incoming: Vec, pub outgoing: VecDeque, diff --git a/tests/ossl/mod.rs b/tests/ossl/mod.rs index f1b431c..56bc60e 100644 --- a/tests/ossl/mod.rs +++ b/tests/ossl/mod.rs @@ -29,7 +29,7 @@ use std::io; pub use cert::{DtlsCertOptions, DtlsPKeyType, Fingerprint, OsslDtlsCert}; -mod io_buf; +pub mod io_buf; mod stream; mod dtls;