diff --git a/CHANGELOG.md b/CHANGELOG.md index 83205cf0..82ea31f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased + * Remove PrfProvider/HkdfProvider, derive from HmacProvider (breaking) #94 + # 0.4.3 * Fix server auto-sensing DTLS version with fragmented ClientHello #87 diff --git a/src/crypto/aws_lc_rs/hkdf.rs b/src/crypto/aws_lc_rs/hkdf.rs deleted file mode 100644 index 0af640d6..00000000 --- a/src/crypto/aws_lc_rs/hkdf.rs +++ /dev/null @@ -1,401 +0,0 @@ -//! HKDF implementation using aws-lc-rs for TLS 1.3 key derivation. - -use aws_lc_rs::hkdf::{HKDF_SHA256, HKDF_SHA384, KeyType, Prk}; -use aws_lc_rs::hmac; - -use super::super::HkdfProvider; -use crate::buffer::Buf; -use crate::types::HashAlgorithm; - -/// Custom KeyType implementation for arbitrary output lengths. -struct OutputLen(usize); - -impl KeyType for OutputLen { - fn len(&self) -> usize { - self.0 - } -} - -/// HKDF provider implementation using aws-lc-rs. -#[derive(Debug)] -pub(super) struct AwsLcHkdfProvider; - -impl HkdfProvider for AwsLcHkdfProvider { - fn hkdf_extract( - &self, - hash: HashAlgorithm, - salt: &[u8], - ikm: &[u8], - out: &mut Buf, - ) -> Result<(), String> { - out.clear(); - - // HKDF-Extract is defined as HMAC-Hash(salt, IKM) - // Per RFC 5869: PRK = HMAC-Hash(salt, IKM) - // If salt is empty, use a string of HashLen zeros - let hash_len = hash.output_len(); - let algorithm = match hash { - HashAlgorithm::SHA256 => hmac::HMAC_SHA256, - HashAlgorithm::SHA384 => hmac::HMAC_SHA384, - _ => return Err(format!("Unsupported hash for HKDF: {:?}", hash)), - }; - - // If salt is empty, use zero-filled salt of hash length - let salt_bytes: Vec; - let actual_salt = if salt.is_empty() { - salt_bytes = vec![0u8; hash_len]; - &salt_bytes[..] - } else { - salt - }; - - let key = hmac::Key::new(algorithm, actual_salt); - let prk = hmac::sign(&key, ikm); - - out.extend_from_slice(prk.as_ref()); - Ok(()) - } - - fn hkdf_expand( - &self, - hash: HashAlgorithm, - prk: &[u8], - info: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - out.clear(); - - let algorithm = match hash { - HashAlgorithm::SHA256 => HKDF_SHA256, - HashAlgorithm::SHA384 => HKDF_SHA384, - _ => return Err(format!("Unsupported hash for HKDF: {:?}", hash)), - }; - - let prk = Prk::new_less_safe(algorithm, prk); - let info_slice = [info]; - let okm = prk - .expand(&info_slice, OutputLen(output_len)) - .map_err(|e| format!("HKDF expand failed: {:?}", e))?; - - let mut output = vec![0u8; output_len]; - okm.fill(&mut output) - .map_err(|e| format!("HKDF fill failed: {:?}", e))?; - - out.extend_from_slice(&output); - Ok(()) - } - - fn hkdf_expand_label( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - // Build the HkdfLabel structure per RFC 8446 Section 7.1: - // - // struct { - // uint16 length = Length; - // opaque label<7..255> = "tls13 " + Label; - // opaque context<0..255> = Context; - // } HkdfLabel; - // - // The label must be prefixed with "tls13 " (6 bytes) - - let full_label_len = 6 + label.len(); // "tls13 " + label - - if full_label_len > 255 { - return Err("Label too long for HKDF-Expand-Label".to_string()); - } - if context.len() > 255 { - return Err("Context too long for HKDF-Expand-Label".to_string()); - } - if output_len > 65535 { - return Err("Output length too large for HKDF-Expand-Label".to_string()); - } - - // Build the info (HkdfLabel) - let info_len = 2 + 1 + full_label_len + 1 + context.len(); - let mut info = Vec::with_capacity(info_len); - - // uint16 length - info.extend_from_slice(&(output_len as u16).to_be_bytes()); - - // opaque label<7..255> = "tls13 " + Label - info.push(full_label_len as u8); - info.extend_from_slice(b"tls13 "); - info.extend_from_slice(label); - - // opaque context<0..255> - info.push(context.len() as u8); - info.extend_from_slice(context); - - // Now do regular HKDF-Expand - self.hkdf_expand(hash, secret, &info, out, output_len) - } - - fn hkdf_expand_label_dtls13( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - // Build the HkdfLabel structure for DTLS 1.3 per RFC 9147: - // - // struct { - // uint16 length = Length; - // opaque label<6..255> = "dtls13" + Label; - // opaque context<0..255> = Context; - // } HkdfLabel; - // - // Note: DTLS 1.3 uses "dtls13" prefix (6 bytes, no space) instead of "tls13 " - - let full_label_len = 6 + label.len(); // "dtls13" + label - - if full_label_len > 255 { - return Err("Label too long for HKDF-Expand-Label".to_string()); - } - if context.len() > 255 { - return Err("Context too long for HKDF-Expand-Label".to_string()); - } - if output_len > 65535 { - return Err("Output length too large for HKDF-Expand-Label".to_string()); - } - - // Build the info (HkdfLabel) - let info_len = 2 + 1 + full_label_len + 1 + context.len(); - let mut info = Vec::with_capacity(info_len); - - // uint16 length - info.extend_from_slice(&(output_len as u16).to_be_bytes()); - - // opaque label<6..255> = "dtls13" + Label - info.push(full_label_len as u8); - info.extend_from_slice(b"dtls13"); - info.extend_from_slice(label); - - // opaque context<0..255> - info.push(context.len() as u8); - info.extend_from_slice(context); - - // Now do regular HKDF-Expand - self.hkdf_expand(hash, secret, &info, out, output_len) - } -} - -/// Static instance of the HKDF provider. -pub(super) static HKDF_PROVIDER: AwsLcHkdfProvider = AwsLcHkdfProvider; - -#[cfg(test)] -mod tests { - use super::*; - - // RFC 5869 Test Case 1 - Basic test case with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case1() { - let provider = AwsLcHkdfProvider; - - // IKM = 0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b (22 bytes) - let ikm = [0x0b; 22]; - - // salt = 0x000102030405060708090a0b0c (13 bytes) - let salt = [ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, - ]; - - // info = 0xf0f1f2f3f4f5f6f7f8f9 (10 bytes) - let info = [0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9]; - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, - 0xba, 0x63, 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, 0x22, 0xec, 0x84, 0x4a, - 0xd7, 0xc2, 0xb3, 0xe5, - ]; - - // Expected OKM (42 bytes) - let expected_okm = [ - 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, - 0x2f, 0x2a, 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, 0x5d, 0xb0, 0x2d, 0x56, - 0xec, 0xc4, 0xc5, 0xbf, 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 42) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // RFC 5869 Test Case 2 - Longer inputs/outputs with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case2() { - let provider = AwsLcHkdfProvider; - - // IKM = 0x000102...4f (80 bytes) - let ikm: Vec = (0x00..=0x4f).collect(); - - // salt = 0x606162...af (80 bytes) - let salt: Vec = (0x60..=0xaf).collect(); - - // info = 0xb0b1b2...ff (80 bytes) - let info: Vec = (0xb0..=0xff).collect(); - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x06, 0xa6, 0xb8, 0x8c, 0x58, 0x53, 0x36, 0x1a, 0x06, 0x10, 0x4c, 0x9c, 0xeb, 0x35, - 0xb4, 0x5c, 0xef, 0x76, 0x00, 0x14, 0x90, 0x46, 0x71, 0x01, 0x4a, 0x19, 0x3f, 0x40, - 0xc1, 0x5f, 0xc2, 0x44, - ]; - - // Expected OKM (82 bytes) - let expected_okm = [ - 0xb1, 0x1e, 0x39, 0x8d, 0xc8, 0x03, 0x27, 0xa1, 0xc8, 0xe7, 0xf7, 0x8c, 0x59, 0x6a, - 0x49, 0x34, 0x4f, 0x01, 0x2e, 0xda, 0x2d, 0x4e, 0xfa, 0xd8, 0xa0, 0x50, 0xcc, 0x4c, - 0x19, 0xaf, 0xa9, 0x7c, 0x59, 0x04, 0x5a, 0x99, 0xca, 0xc7, 0x82, 0x72, 0x71, 0xcb, - 0x41, 0xc6, 0x5e, 0x59, 0x0e, 0x09, 0xda, 0x32, 0x75, 0x60, 0x0c, 0x2f, 0x09, 0xb8, - 0x36, 0x77, 0x93, 0xa9, 0xac, 0xa3, 0xdb, 0x71, 0xcc, 0x30, 0xc5, 0x81, 0x79, 0xec, - 0x3e, 0x87, 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, 0x1d, 0x87, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 82) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // RFC 5869 Test Case 3 - Zero-length salt and info with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case3() { - let provider = AwsLcHkdfProvider; - - // IKM = 0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b (22 bytes) - let ikm = [0x0b; 22]; - - // salt = empty - let salt: [u8; 0] = []; - - // info = empty - let info: [u8; 0] = []; - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16, 0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, - 0x8b, 0xdf, 0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77, 0xac, 0x43, 0x4c, 0x1c, - 0x29, 0x3c, 0xcb, 0x04, - ]; - - // Expected OKM (42 bytes) - let expected_okm = [ - 0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, - 0x5a, 0x31, 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, 0xc3, 0x45, 0x4e, 0x5f, - 0x3c, 0x73, 0x8d, 0x2d, 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, 0x96, 0xc8, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 42) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // Test HKDF-Expand-Label structure is built correctly - #[test] - fn test_hkdf_expand_label_basic() { - let provider = AwsLcHkdfProvider; - let secret = [0u8; 32]; - let mut out = Buf::new(); - - // Should succeed with valid inputs - provider - .hkdf_expand_label(HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16) - .unwrap(); - assert_eq!(out.len(), 16); - - // Should succeed with context - provider - .hkdf_expand_label( - HashAlgorithm::SHA256, - &secret, - b"iv", - &[1, 2, 3, 4], - &mut out, - 12, - ) - .unwrap(); - assert_eq!(out.len(), 12); - } - - // Test DTLS 1.3 expand label - #[test] - fn test_hkdf_expand_label_dtls13_basic() { - let provider = AwsLcHkdfProvider; - let secret = [0u8; 32]; - let mut out = Buf::new(); - - // Should succeed with valid inputs - provider - .hkdf_expand_label_dtls13(HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16) - .unwrap(); - assert_eq!(out.len(), 16); - - // TLS 1.3 and DTLS 1.3 with same inputs should produce different outputs - // due to different label prefixes ("tls13 " vs "dtls13") - let mut tls_out = Buf::new(); - let mut dtls_out = Buf::new(); - - provider - .hkdf_expand_label( - HashAlgorithm::SHA256, - &secret, - b"key", - &[], - &mut tls_out, - 16, - ) - .unwrap(); - provider - .hkdf_expand_label_dtls13( - HashAlgorithm::SHA256, - &secret, - b"key", - &[], - &mut dtls_out, - 16, - ) - .unwrap(); - - assert_ne!(&*tls_out, &*dtls_out); - } -} diff --git a/src/crypto/aws_lc_rs/hmac.rs b/src/crypto/aws_lc_rs/hmac.rs index 97a324cb..eaadc3ce 100644 --- a/src/crypto/aws_lc_rs/hmac.rs +++ b/src/crypto/aws_lc_rs/hmac.rs @@ -1,48 +1,12 @@ -//! HMAC utilities using aws-lc-rs. +//! HMAC implementation using aws-lc-rs. use aws_lc_rs::hmac; use super::super::HmacProvider; -use crate::buffer::Buf; use crate::types::HashAlgorithm; -/// Compute HMAC using TLS 1.2 P_hash algorithm. -pub(super) fn p_hash( - algorithm: hmac::Algorithm, - secret: &[u8], - full_seed: &[u8], - out: &mut Buf, - output_len: usize, -) -> Result<(), String> { - out.clear(); - - let key = hmac::Key::new(algorithm, secret); - - // A(1) = HMAC_hash(secret, A(0)) where A(0) = seed - let mut a = hmac::sign(&key, full_seed); - - while out.len() < output_len { - // HMAC_hash(secret, A(i) + seed) - let mut ctx = hmac::Context::with_key(&key); - ctx.update(a.as_ref()); - ctx.update(full_seed); - let output = ctx.sign(); - - let remaining = output_len - out.len(); - let to_copy = std::cmp::min(remaining, output.as_ref().len()); - out.extend_from_slice(&output.as_ref()[..to_copy]); - - if out.len() < output_len { - // A(i+1) = HMAC_hash(secret, A(i)) - a = hmac::sign(&key, a.as_ref()); - } - } - - Ok(()) -} - /// Get HMAC algorithm from hash algorithm. -pub(super) fn hmac_algorithm(hash: HashAlgorithm) -> Result { +fn hmac_algorithm(hash: HashAlgorithm) -> Result { match hash { HashAlgorithm::SHA256 => Ok(hmac::HMAC_SHA256), HashAlgorithm::SHA384 => Ok(hmac::HMAC_SHA384), @@ -52,17 +16,24 @@ pub(super) fn hmac_algorithm(hash: HashAlgorithm) -> Result Result<[u8; 32], String> { - let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key); + fn hmac( + &self, + hash: HashAlgorithm, + key: &[u8], + data: &[u8], + out: &mut [u8], + ) -> Result { + let algorithm = hmac_algorithm(hash)?; + let hmac_key = hmac::Key::new(algorithm, key); let tag = hmac::sign(&hmac_key, data); - let mut result = [0u8; 32]; - result.copy_from_slice(tag.as_ref()); - Ok(result) + let len = tag.as_ref().len(); + out[..len].copy_from_slice(tag.as_ref()); + Ok(len) } } /// Static instance of the HMAC provider. -pub(super) static HMAC_PROVIDER: AwsLcHmacProvider = AwsLcHmacProvider; +pub(crate) static HMAC_PROVIDER: AwsLcHmacProvider = AwsLcHmacProvider; diff --git a/src/crypto/aws_lc_rs/mod.rs b/src/crypto/aws_lc_rs/mod.rs index 4ac662cd..0a6d89ce 100644 --- a/src/crypto/aws_lc_rs/mod.rs +++ b/src/crypto/aws_lc_rs/mod.rs @@ -59,12 +59,10 @@ mod cipher_suite; mod hash; -mod hkdf; -mod hmac; +pub(crate) mod hmac; mod kx_group; mod random; mod sign; -mod tls12; use super::CryptoProvider; @@ -102,12 +100,6 @@ use super::CryptoProvider; /// - SEC1 DER format (OpenSSL EC private key format) /// - PEM encoded versions of the above /// -/// # TLS 1.2 PRF -/// -/// Implements the TLS 1.2 PRF for key derivation, including: -/// - Standard PRF for master secret and key expansion -/// - Extended Master Secret (RFC 7627) for improved security -/// /// # Random Number Generation /// /// Uses `SystemRandom` from aws-lc-rs for cryptographically secure random number generation. @@ -122,9 +114,7 @@ pub fn default_provider() -> CryptoProvider { hmac_provider: &hmac::HMAC_PROVIDER, // DTLS 1.2 components cipher_suites: cipher_suite::ALL_CIPHER_SUITES, - prf_provider: &tls12::PRF_PROVIDER, // DTLS 1.3 components dtls13_cipher_suites: cipher_suite::ALL_DTLS13_CIPHER_SUITES, - hkdf_provider: &hkdf::HKDF_PROVIDER, } } diff --git a/src/crypto/aws_lc_rs/tls12.rs b/src/crypto/aws_lc_rs/tls12.rs deleted file mode 100644 index cecb7830..00000000 --- a/src/crypto/aws_lc_rs/tls12.rs +++ /dev/null @@ -1,37 +0,0 @@ -//! TLS 1.2 PRF using aws-lc-rs. - -use super::super::PrfProvider; -use crate::buffer::Buf; -use crate::types::HashAlgorithm; - -use super::hmac; - -/// PRF provider implementation for TLS 1.2. -#[derive(Debug)] -pub(super) struct AwsLcPrfProvider; - -impl PrfProvider for AwsLcPrfProvider { - fn prf_tls12( - &self, - secret: &[u8], - label: &str, - seed: &[u8], - out: &mut Buf, - output_len: usize, - scratch: &mut Buf, - hash: HashAlgorithm, - ) -> Result<(), String> { - assert!(label.is_ascii(), "Label must be ASCII"); - - // Use scratch buffer for full_seed concatenation - scratch.clear(); - scratch.extend_from_slice(label.as_bytes()); - scratch.extend_from_slice(seed); - - let algorithm = hmac::hmac_algorithm(hash)?; - hmac::p_hash(algorithm, secret, scratch, out, output_len) - } -} - -/// Static instance of the PRF provider. -pub(super) static PRF_PROVIDER: AwsLcPrfProvider = AwsLcPrfProvider; diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9c53469f..0c186714 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -13,6 +13,7 @@ pub mod aws_lc_rs; pub mod rust_crypto; mod dtls_aead; +pub mod prf_hkdf; mod provider; mod validation; @@ -30,9 +31,9 @@ pub use crate::buffer::{Buf, TmpBuf}; // Re-export all provider traits and types (similar to rustls structure) // This allows users to do: use dimpl::crypto::{CryptoProvider, SupportedDtls12CipherSuite, ...}; pub use provider::{ - ActiveKeyExchange, Cipher, CryptoProvider, CryptoSafe, HashContext, HashProvider, HkdfProvider, - HmacProvider, KeyProvider, PrfProvider, SecureRandom, SignatureVerifier, SigningKey, - SupportedDtls12CipherSuite, SupportedDtls13CipherSuite, SupportedKxGroup, check_verify_scheme, + ActiveKeyExchange, Cipher, CryptoProvider, CryptoSafe, HashContext, HashProvider, HmacProvider, + KeyProvider, SecureRandom, SignatureVerifier, SigningKey, SupportedDtls12CipherSuite, + SupportedDtls13CipherSuite, SupportedKxGroup, check_verify_scheme, }; #[cfg(feature = "_crypto-common")] pub use provider::{OID_P256, OID_P384, cert_named_group}; diff --git a/src/crypto/prf_hkdf.rs b/src/crypto/prf_hkdf.rs new file mode 100644 index 00000000..64a886c1 --- /dev/null +++ b/src/crypto/prf_hkdf.rs @@ -0,0 +1,537 @@ +//! PRF and HKDF key derivation built on top of [`HmacProvider`]. +//! +//! Both TLS 1.2 PRF and TLS 1.3 HKDF are pure compositions of HMAC calls. +//! This module provides generic implementations so that crypto backends only +//! need to implement [`HmacProvider`] — no separate PRF or HKDF providers. + +use crate::buffer::Buf; +use crate::types::HashAlgorithm; + +use super::HmacProvider; + +/// Maximum HMAC output size we support (SHA-384 = 48 bytes). +const MAX_HASH_LEN: usize = 48; + +// ============================================================================ +// TLS 1.2 PRF (RFC 5246 Section 5) +// ============================================================================ + +/// TLS 1.2 PRF: `PRF(secret, label, seed)` writing `output_len` bytes to `out`. +/// +/// Uses `scratch` for temporary concatenation of label+seed. +#[allow(clippy::too_many_arguments)] +pub fn prf_tls12( + hmac: &dyn HmacProvider, + secret: &[u8], + label: &str, + seed: &[u8], + out: &mut Buf, + output_len: usize, + scratch: &mut Buf, + hash: HashAlgorithm, +) -> Result<(), String> { + let mut hmac_a = [0u8; MAX_HASH_LEN]; + + // Build label + seed + scratch.clear(); + scratch.extend_from_slice(label.as_bytes()); + scratch.extend_from_slice(seed); + let label_seed = scratch.as_ref(); + + // A(1) = HMAC(secret, label_seed) + let hash_len = hmac.hmac(hash, secret, label_seed, &mut hmac_a)?; + + // Build payload = A(i) || label || seed + scratch.clear(); + scratch.extend_from_slice(&hmac_a[..hash_len]); + scratch.extend_from_slice(label.as_bytes()); + scratch.extend_from_slice(seed); + let payload = scratch.as_mut(); + + out.clear(); + while out.len() < output_len { + // P(i) = HMAC(secret, A(i) || label || seed) + let mut hmac_block = [0u8; MAX_HASH_LEN]; + let block_len = hmac.hmac(hash, secret, payload, &mut hmac_block)?; + + let remaining = output_len - out.len(); + let to_copy = remaining.min(block_len); + out.extend_from_slice(&hmac_block[..to_copy]); + + if out.len() < output_len { + // A(i+1) = HMAC(secret, A(i)) + hmac.hmac(hash, secret, &payload[..hash_len], &mut hmac_a)?; + payload[..hash_len].copy_from_slice(&hmac_a[..hash_len]); + } + } + + Ok(()) +} + +// ============================================================================ +// HKDF (RFC 5869) +// ============================================================================ + +/// HKDF-Extract: `PRK = HMAC-Hash(salt, IKM)`. +pub fn hkdf_extract( + hmac: &dyn HmacProvider, + hash: HashAlgorithm, + salt: &[u8], + ikm: &[u8], + out: &mut Buf, +) -> Result<(), String> { + out.clear(); + + let hash_len = hash.output_len(); + let zero_salt: Vec; + let actual_salt = if salt.is_empty() { + zero_salt = vec![0u8; hash_len]; + &zero_salt[..] + } else { + salt + }; + + let mut prk = [0u8; MAX_HASH_LEN]; + let prk_len = hmac.hmac(hash, actual_salt, ikm, &mut prk)?; + out.extend_from_slice(&prk[..prk_len]); + Ok(()) +} + +/// HKDF-Expand: expand `prk` to `output_len` bytes. +pub fn hkdf_expand( + hmac: &dyn HmacProvider, + hash: HashAlgorithm, + prk: &[u8], + info: &[u8], + out: &mut Buf, + output_len: usize, +) -> Result<(), String> { + let hash_len = hash.output_len(); + let n = output_len.div_ceil(hash_len); + if n > 255 { + return Err("HKDF output too long".into()); + } + + let mut t_prev = [0u8; MAX_HASH_LEN]; + let mut t_prev_len = 0usize; + + out.clear(); + for i in 1..=n { + let mut input = Vec::with_capacity(t_prev_len + info.len() + 1); + input.extend_from_slice(&t_prev[..t_prev_len]); + input.extend_from_slice(info); + input.push(i as u8); + + t_prev_len = hmac.hmac(hash, prk, &input, &mut t_prev)?; + + let remaining = output_len - out.len(); + let to_copy = remaining.min(t_prev_len); + out.extend_from_slice(&t_prev[..to_copy]); + } + + Ok(()) +} + +/// HKDF-Expand-Label for TLS 1.3 (RFC 8446 Section 7.1). +/// +/// Uses the `"tls13 "` prefix. +pub fn hkdf_expand_label( + hmac: &dyn HmacProvider, + hash: HashAlgorithm, + secret: &[u8], + label: &[u8], + context: &[u8], + out: &mut Buf, + output_len: usize, +) -> Result<(), String> { + let info = build_hkdf_label(b"tls13 ", label, context, output_len)?; + hkdf_expand(hmac, hash, secret, &info, out, output_len) +} + +/// HKDF-Expand-Label for DTLS 1.3 (RFC 9147). +/// +/// Uses the `"dtls13"` prefix (no trailing space). +pub fn hkdf_expand_label_dtls13( + hmac: &dyn HmacProvider, + hash: HashAlgorithm, + secret: &[u8], + label: &[u8], + context: &[u8], + out: &mut Buf, + output_len: usize, +) -> Result<(), String> { + let info = build_hkdf_label(b"dtls13", label, context, output_len)?; + hkdf_expand(hmac, hash, secret, &info, out, output_len) +} + +/// Build the HkdfLabel structure. +/// +/// ```text +/// struct { +/// uint16 length; +/// opaque label<6..255> = prefix + Label; +/// opaque context<0..255> = Context; +/// } HkdfLabel; +/// ``` +fn build_hkdf_label( + prefix: &[u8], + label: &[u8], + context: &[u8], + output_len: usize, +) -> Result, String> { + let full_label_len = prefix.len() + label.len(); + + if full_label_len > 255 { + return Err("Label too long for HKDF-Expand-Label".into()); + } + if context.len() > 255 { + return Err("Context too long for HKDF-Expand-Label".into()); + } + if output_len > 65535 { + return Err("Output length too large for HKDF-Expand-Label".into()); + } + + let info_len = 2 + 1 + full_label_len + 1 + context.len(); + let mut info = Vec::with_capacity(info_len); + + // uint16 length + info.extend_from_slice(&(output_len as u16).to_be_bytes()); + // opaque label + info.push(full_label_len as u8); + info.extend_from_slice(prefix); + info.extend_from_slice(label); + // opaque context + info.push(context.len() as u8); + info.extend_from_slice(context); + + Ok(info) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn hex_to_vec(hex: &str) -> Vec { + let hex = hex.replace(' ', "").replace('\n', ""); + let mut v = Vec::new(); + for i in 0..hex.len() / 2 { + // unwrap: test-only hex parsing + let byte = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16).unwrap(); + v.push(byte); + } + v + } + + fn slice_to_hex(data: &[u8]) -> String { + let mut s = String::new(); + for byte in data.iter() { + s.push_str(&format!("{:02x}", byte)); + } + s + } + + /// Convert an ASCII hex array into a byte array at compile time. + macro_rules! hex_as_bytes { + ($input:expr) => {{ + const fn from_hex_char(c: u8) -> u8 { + match c { + b'0'..=b'9' => c - b'0', + b'a'..=b'f' => c - b'a' + 10, + b'A'..=b'F' => c - b'A' + 10, + _ => panic!("Invalid hex character"), + } + } + + const INPUT: &[u8] = $input; + const LEN: usize = INPUT.len(); + const OUTPUT_LEN: usize = LEN / 2; + + const fn convert() -> [u8; OUTPUT_LEN] { + assert!(LEN % 2 == 0, "Hex string length must be even"); + + let mut out = [0u8; OUTPUT_LEN]; + let mut i = 0; + while i < LEN { + out[i / 2] = (from_hex_char(INPUT[i]) << 4) | from_hex_char(INPUT[i + 1]); + i += 2; + } + out + } + + convert() + }}; + } + + /// We need a concrete HmacProvider for tests. Use the default feature-gated one. + fn hmac_provider() -> &'static dyn HmacProvider { + #[cfg(feature = "aws-lc-rs")] + { + &crate::crypto::aws_lc_rs::hmac::HMAC_PROVIDER + } + #[cfg(all(not(feature = "aws-lc-rs"), feature = "rust-crypto"))] + { + &crate::crypto::rust_crypto::hmac::HMAC_PROVIDER + } + } + + // ======================================================================== + // HMAC-SHA-256 Test Vectors from RFC 4231 + // ======================================================================== + + #[test] + fn hmac_sha256_test_case_1() { + let key = hex_to_vec("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"); + let data = b"Hi There"; + let expected = "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"; + + let result = hmac_provider().hmac_sha256(&key, data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + #[test] + fn hmac_sha256_test_case_2() { + let key = b"Jefe"; + let data = b"what do ya want for nothing?"; + let expected = "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"; + + let result = hmac_provider().hmac_sha256(key, data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + #[test] + fn hmac_sha256_test_case_3() { + let key = hex_to_vec("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let data = hex_to_vec( + "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd\ + dddddddddddddddddddddddddddddddddddd", + ); + let expected = "773ea91e36800e46854db8ebd09181a72959098b3ef8c122d9635514ced565fe"; + + let result = hmac_provider().hmac_sha256(&key, &data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + #[test] + fn hmac_sha256_test_case_4() { + let key = hex_to_vec("0102030405060708090a0b0c0d0e0f10111213141516171819"); + let data = hex_to_vec( + "cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd\ + cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd", + ); + let expected = "82558a389a443c0ea4cc819899f2083a85f0faa3e578f8077a2e3ff46729665b"; + + let result = hmac_provider().hmac_sha256(&key, &data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + #[test] + fn hmac_sha256_test_case_6() { + // Test with a key larger than block size (> 64 bytes) + let key = vec![0xaa; 131]; + let data = b"Test Using Larger Than Block-Size Key - Hash Key First"; + let expected = "60e431591ee0b67f0d8a26aacbf5b77f8e0bc6213728c5140546040f0ee37f54"; + + let result = hmac_provider().hmac_sha256(&key, data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + #[test] + fn hmac_sha256_test_case_7() { + // Test with a key larger than block size and large data + let key = vec![0xaa; 131]; + let data = + b"This is a test using a larger than block-size key and a larger than block-size \ + data. The key needs to be hashed before being used by the HMAC algorithm."; + let expected = "9b09ffa71b942fcb27635fbcd5b0e944bfdc63644f0713938a7f51535c3a35e2"; + + let result = hmac_provider().hmac_sha256(&key, data).unwrap(); + assert_eq!(slice_to_hex(&result), expected); + } + + // ======================================================================== + // TLS 1.2 PRF + // ======================================================================== + + #[test] + fn prf_tls12_sha256() { + // Test vector from https://github.com/xomexh/TLS-PRF + let mut output = Buf::new(); + let mut scratch = Buf::new(); + prf_tls12( + hmac_provider(), + &hex_as_bytes!(b"9bbe436ba940f017b17652849a71db35"), + "test label", + &hex_as_bytes!(b"a0ba9f936cda311827a6f796ffd5198c"), + &mut output, + 100, + &mut scratch, + HashAlgorithm::SHA256, + ) + .unwrap(); + assert_eq!( + output.as_ref(), + &hex_as_bytes!( + b"e3f229ba727be17b8d122620557cd453c2aab21d\ + 07c3d495329b52d4e61edb5a6b301791e90d35c9\ + c9a46b4e14baf9af0fa022f7077def17abfd3797\ + c0564bab4fbc91666e9def9b97fce34f796789ba\ + a48082d122ee42c5a72e5a5110fff70187347b66" + ) + ); + } + + // ======================================================================== + // HKDF Test Vectors from RFC 5869 + // ======================================================================== + + #[test] + fn hkdf_sha256_rfc5869_case1() { + // Test Case 1 - Basic test case with SHA-256 + let ikm = hex_to_vec("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"); + let salt = hex_to_vec("000102030405060708090a0b0c"); + let info = hex_to_vec("f0f1f2f3f4f5f6f7f8f9"); + let expected_prk = "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"; + let expected_okm = "3cb25f25faacd57a90434f64d0362f2a\ + 2d2d0a90cf1a5a4c5db02d56ecc4c5bf\ + 34007208d5b887185865"; + + let h = hmac_provider(); + + let mut prk = Buf::new(); + hkdf_extract(h, HashAlgorithm::SHA256, &salt, &ikm, &mut prk).unwrap(); + assert_eq!(slice_to_hex(prk.as_ref()), expected_prk); + + let mut okm = Buf::new(); + hkdf_expand(h, HashAlgorithm::SHA256, prk.as_ref(), &info, &mut okm, 42).unwrap(); + assert_eq!(slice_to_hex(okm.as_ref()), expected_okm); + } + + #[test] + fn hkdf_sha256_rfc5869_case2() { + // Test Case 2 - Longer inputs/outputs with SHA-256 + let ikm = hex_to_vec( + "000102030405060708090a0b0c0d0e0f\ + 101112131415161718191a1b1c1d1e1f\ + 202122232425262728292a2b2c2d2e2f\ + 303132333435363738393a3b3c3d3e3f\ + 404142434445464748494a4b4c4d4e4f", + ); + let salt = hex_to_vec( + "606162636465666768696a6b6c6d6e6f\ + 707172737475767778797a7b7c7d7e7f\ + 808182838485868788898a8b8c8d8e8f\ + 909192939495969798999a9b9c9d9e9f\ + a0a1a2a3a4a5a6a7a8a9aaabacadaeaf", + ); + let info = hex_to_vec( + "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf\ + c0c1c2c3c4c5c6c7c8c9cacbcccdcecf\ + d0d1d2d3d4d5d6d7d8d9dadbdcdddedf\ + e0e1e2e3e4e5e6e7e8e9eaebecedeeef\ + f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff", + ); + let expected_prk = "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244"; + let expected_okm = "b11e398dc80327a1c8e7f78c596a4934\ + 4f012eda2d4efad8a050cc4c19afa97c\ + 59045a99cac7827271cb41c65e590e09\ + da3275600c2f09b8367793a9aca3db71\ + cc30c58179ec3e87c14c01d5c1f3434f\ + 1d87"; + + let h = hmac_provider(); + + let mut prk = Buf::new(); + hkdf_extract(h, HashAlgorithm::SHA256, &salt, &ikm, &mut prk).unwrap(); + assert_eq!(slice_to_hex(prk.as_ref()), expected_prk); + + let mut okm = Buf::new(); + hkdf_expand(h, HashAlgorithm::SHA256, prk.as_ref(), &info, &mut okm, 82).unwrap(); + assert_eq!(slice_to_hex(okm.as_ref()), expected_okm); + } + + #[test] + fn hkdf_sha256_rfc5869_case3() { + // Test Case 3 - Zero-length salt and info with SHA-256 + let ikm = hex_to_vec("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"); + let salt = vec![]; + let info = vec![]; + let expected_prk = "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04"; + let expected_okm = "8da4e775a563c18f715f802a063c5a31\ + b8a11f5c5ee1879ec3454e5f3c738d2d\ + 9d201395faa4b61a96c8"; + + let h = hmac_provider(); + + let mut prk = Buf::new(); + hkdf_extract(h, HashAlgorithm::SHA256, &salt, &ikm, &mut prk).unwrap(); + assert_eq!(slice_to_hex(prk.as_ref()), expected_prk); + + let mut okm = Buf::new(); + hkdf_expand(h, HashAlgorithm::SHA256, prk.as_ref(), &info, &mut okm, 42).unwrap(); + assert_eq!(slice_to_hex(okm.as_ref()), expected_okm); + } + + // ======================================================================== + // HKDF-Expand-Label + // ======================================================================== + + #[test] + fn hkdf_expand_label_basic() { + let h = hmac_provider(); + let secret = [0u8; 32]; + let mut out = Buf::new(); + + hkdf_expand_label(h, HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16).unwrap(); + assert_eq!(out.len(), 16); + + hkdf_expand_label( + h, + HashAlgorithm::SHA256, + &secret, + b"iv", + &[1, 2, 3, 4], + &mut out, + 12, + ) + .unwrap(); + assert_eq!(out.len(), 12); + } + + #[test] + fn hkdf_expand_label_dtls13_basic() { + let h = hmac_provider(); + let secret = [0u8; 32]; + let mut out = Buf::new(); + + hkdf_expand_label_dtls13(h, HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16) + .unwrap(); + assert_eq!(out.len(), 16); + + // TLS 1.3 and DTLS 1.3 with same inputs should produce different outputs + let mut tls_out = Buf::new(); + let mut dtls_out = Buf::new(); + + hkdf_expand_label( + h, + HashAlgorithm::SHA256, + &secret, + b"key", + &[], + &mut tls_out, + 16, + ) + .unwrap(); + hkdf_expand_label_dtls13( + h, + HashAlgorithm::SHA256, + &secret, + b"key", + &[], + &mut dtls_out, + 16, + ) + .unwrap(); + + assert_ne!(tls_out.as_ref(), dtls_out.as_ref()); + } +} diff --git a/src/crypto/provider.rs b/src/crypto/provider.rs index b3e07236..5def5fd5 100644 --- a/src/crypto/provider.rs +++ b/src/crypto/provider.rs @@ -19,8 +19,7 @@ //! - **Key Provider** ([`KeyProvider`]): Parse and load private keys //! - **Secure Random** ([`SecureRandom`]): Cryptographically secure RNG //! - **Hash Provider** ([`HashProvider`]): Factory for hash contexts -//! - **PRF Provider** ([`PrfProvider`]): TLS 1.2 PRF for key derivation -//! - **HMAC Provider** ([`HmacProvider`]): Compute HMAC signatures +//! - **HMAC Provider** ([`HmacProvider`]): Compute HMAC signatures (also drives PRF and HKDF) //! //! # Using a Custom Provider //! @@ -135,7 +134,7 @@ //! - **Key exchange**: ECDHE with X25519, P-256, or P-384 curves //! - **Signatures**: ECDSA with P-256/SHA-256 or P-384/SHA-384 //! - **Hash**: SHA-256 and SHA-384 -//! - **PRF**: TLS 1.2 PRF (using HMAC-SHA256 or HMAC-SHA384) +//! - **HMAC**: HMAC-SHA256 and HMAC-SHA384 (used for PRF, HKDF, and cookies) //! //! # Thread Safety //! @@ -375,27 +374,25 @@ pub trait HashProvider: CryptoSafe { fn create_hash(&self, algorithm: HashAlgorithm) -> Box; } -/// PRF (Pseudo-Random Function) for TLS 1.2 key derivation. -pub trait PrfProvider: CryptoSafe { - /// TLS 1.2 PRF: PRF(secret, label, seed) writing output to `out`. - /// Uses `scratch` for temporary concatenation of label+seed. - #[allow(clippy::too_many_arguments)] - fn prf_tls12( - &self, - secret: &[u8], - label: &str, - seed: &[u8], - out: &mut Buf, - output_len: usize, - scratch: &mut Buf, - hash: HashAlgorithm, - ) -> Result<(), String>; -} - /// HMAC provider for computing HMAC signatures. pub trait HmacProvider: CryptoSafe { /// Compute HMAC-SHA256(key, data) and return the result. - fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Result<[u8; 32], String>; + fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Result<[u8; 32], String> { + let mut out = [0u8; 32]; + self.hmac(HashAlgorithm::SHA256, key, data, &mut out)?; + Ok(out) + } + + /// Compute HMAC for the given hash algorithm, writing the result to `out`. + /// + /// Returns the number of bytes written. + fn hmac( + &self, + hash: HashAlgorithm, + key: &[u8], + data: &[u8], + out: &mut [u8], + ) -> Result; } // ============================================================================ @@ -438,70 +435,6 @@ pub trait SupportedDtls13CipherSuite: CryptoSafe { fn encrypt_sn(&self, sn_key: &[u8], sample: &[u8; 16]) -> [u8; 16]; } -/// HKDF provider for TLS 1.3 key derivation (RFC 5869). -/// -/// TLS 1.3 uses HKDF instead of the TLS 1.2 PRF for all key derivation. -pub trait HkdfProvider: CryptoSafe { - /// HKDF-Extract: Extract a pseudorandom key from input keying material. - /// PRK = HKDF-Extract(salt, IKM) - fn hkdf_extract( - &self, - hash: HashAlgorithm, - salt: &[u8], - ikm: &[u8], - out: &mut Buf, - ) -> Result<(), String>; - - /// HKDF-Expand: Expand a pseudorandom key to the desired length. - /// OKM = HKDF-Expand(PRK, info, L) - fn hkdf_expand( - &self, - hash: HashAlgorithm, - prk: &[u8], - info: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String>; - - /// HKDF-Expand-Label for TLS 1.3 (RFC 8446 Section 7.1). - /// Derives key material using the TLS 1.3 label format with "tls13 " prefix. - /// - /// HkdfLabel = struct { - /// uint16 length; - /// opaque label<7..255> = "tls13 " + Label; - /// opaque context<0..255> = Context; - /// } - /// OKM = HKDF-Expand(Secret, HkdfLabel, Length) - fn hkdf_expand_label( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String>; - - /// HKDF-Expand-Label for DTLS 1.3 (RFC 9147). - /// Derives key material using the DTLS 1.3 label format with "dtls13" prefix. - /// - /// HkdfLabel = struct { - /// uint16 length; - /// opaque label<6..255> = "dtls13" + Label; - /// opaque context<0..255> = Context; - /// } - /// OKM = HKDF-Expand(Secret, HkdfLabel, Length) - fn hkdf_expand_label_dtls13( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String>; -} - // ============================================================================ // Core Provider Struct // ============================================================================ @@ -514,12 +447,10 @@ pub trait HkdfProvider: CryptoSafe { /// /// # Version-Specific Components /// -/// Some components are version-specific: -/// - **DTLS 1.2**: Uses `cipher_suites` and `prf_provider` -/// - **DTLS 1.3**: Uses `dtls13_cipher_suites` and `hkdf_provider` -/// /// Shared components like `kx_groups`, `signature_verification`, `key_provider`, /// `secure_random`, `hash_provider`, and `hmac_provider` are used by both versions. +/// PRF (TLS 1.2) and HKDF (TLS 1.3) key derivation are built generically on top +/// of `hmac_provider` — see the [`prf_hkdf`](super::prf_hkdf) module. /// /// # Design /// @@ -548,10 +479,8 @@ pub trait HkdfProvider: CryptoSafe { /// hmac_provider: provider.hmac_provider, /// // DTLS 1.2 components /// cipher_suites: provider.cipher_suites, -/// prf_provider: provider.prf_provider, /// // DTLS 1.3 components /// dtls13_cipher_suites: provider.dtls13_cipher_suites, -/// hkdf_provider: provider.hkdf_provider, /// }; /// # } /// # #[cfg(not(feature = "aws-lc-rs"))] @@ -591,11 +520,6 @@ pub struct CryptoProvider { /// and MAC algorithms together (e.g., TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). pub cipher_suites: &'static [&'static dyn SupportedDtls12CipherSuite], - /// PRF for TLS 1.2 key derivation. - /// - /// The Pseudo-Random Function used for key expansion in DTLS 1.2. - pub prf_provider: &'static dyn PrfProvider, - // ========================================================================= // DTLS 1.3 specific components // ========================================================================= @@ -604,11 +528,6 @@ pub struct CryptoProvider { /// TLS 1.3 cipher suites only specify the AEAD and hash algorithms /// (e.g., TLS_AES_128_GCM_SHA256). Key exchange is negotiated separately. pub dtls13_cipher_suites: &'static [&'static dyn SupportedDtls13CipherSuite], - - /// HKDF provider for TLS 1.3 key derivation. - /// - /// TLS 1.3 uses HKDF instead of the TLS 1.2 PRF for all key derivation. - pub hkdf_provider: &'static dyn HkdfProvider, } /// Static storage for the default crypto provider. diff --git a/src/crypto/rust_crypto/hkdf.rs b/src/crypto/rust_crypto/hkdf.rs deleted file mode 100644 index 3589651f..00000000 --- a/src/crypto/rust_crypto/hkdf.rs +++ /dev/null @@ -1,366 +0,0 @@ -//! HKDF implementation using RustCrypto crates for TLS 1.3 key derivation. - -use hkdf::Hkdf; -use sha2::{Sha256, Sha384}; - -use super::super::HkdfProvider; -use crate::buffer::Buf; -use crate::types::HashAlgorithm; - -/// HKDF provider implementation using RustCrypto. -#[derive(Debug)] -pub(super) struct RustCryptoHkdfProvider; - -impl HkdfProvider for RustCryptoHkdfProvider { - fn hkdf_extract( - &self, - hash: HashAlgorithm, - salt: &[u8], - ikm: &[u8], - out: &mut Buf, - ) -> Result<(), String> { - out.clear(); - - match hash { - HashAlgorithm::SHA256 => { - let salt = if salt.is_empty() { None } else { Some(salt) }; - let (prk, _) = Hkdf::::extract(salt, ikm); - out.extend_from_slice(prk.as_slice()); - } - HashAlgorithm::SHA384 => { - let salt = if salt.is_empty() { None } else { Some(salt) }; - let (prk, _) = Hkdf::::extract(salt, ikm); - out.extend_from_slice(prk.as_slice()); - } - _ => return Err(format!("Unsupported hash for HKDF: {:?}", hash)), - } - - Ok(()) - } - - fn hkdf_expand( - &self, - hash: HashAlgorithm, - prk: &[u8], - info: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - out.clear(); - let mut output = vec![0u8; output_len]; - - match hash { - HashAlgorithm::SHA256 => { - let hk = - Hkdf::::from_prk(prk).map_err(|e| format!("Invalid PRK: {:?}", e))?; - hk.expand(info, &mut output) - .map_err(|e| format!("HKDF expand failed: {:?}", e))?; - } - HashAlgorithm::SHA384 => { - let hk = - Hkdf::::from_prk(prk).map_err(|e| format!("Invalid PRK: {:?}", e))?; - hk.expand(info, &mut output) - .map_err(|e| format!("HKDF expand failed: {:?}", e))?; - } - _ => return Err(format!("Unsupported hash for HKDF: {:?}", hash)), - } - - out.extend_from_slice(&output); - Ok(()) - } - - fn hkdf_expand_label( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - // Build the HkdfLabel structure per RFC 8446 Section 7.1 - let full_label_len = 6 + label.len(); // "tls13 " + label - - if full_label_len > 255 { - return Err("Label too long for HKDF-Expand-Label".to_string()); - } - if context.len() > 255 { - return Err("Context too long for HKDF-Expand-Label".to_string()); - } - if output_len > 65535 { - return Err("Output length too large for HKDF-Expand-Label".to_string()); - } - - // Build the info (HkdfLabel) - let info_len = 2 + 1 + full_label_len + 1 + context.len(); - let mut info = Vec::with_capacity(info_len); - - // uint16 length - info.extend_from_slice(&(output_len as u16).to_be_bytes()); - - // opaque label<7..255> = "tls13 " + Label - info.push(full_label_len as u8); - info.extend_from_slice(b"tls13 "); - info.extend_from_slice(label); - - // opaque context<0..255> - info.push(context.len() as u8); - info.extend_from_slice(context); - - // Now do regular HKDF-Expand - self.hkdf_expand(hash, secret, &info, out, output_len) - } - - fn hkdf_expand_label_dtls13( - &self, - hash: HashAlgorithm, - secret: &[u8], - label: &[u8], - context: &[u8], - out: &mut Buf, - output_len: usize, - ) -> Result<(), String> { - // Build the HkdfLabel structure for DTLS 1.3 per RFC 9147 - let full_label_len = 6 + label.len(); // "dtls13" + label - - if full_label_len > 255 { - return Err("Label too long for HKDF-Expand-Label".to_string()); - } - if context.len() > 255 { - return Err("Context too long for HKDF-Expand-Label".to_string()); - } - if output_len > 65535 { - return Err("Output length too large for HKDF-Expand-Label".to_string()); - } - - // Build the info (HkdfLabel) - let info_len = 2 + 1 + full_label_len + 1 + context.len(); - let mut info = Vec::with_capacity(info_len); - - // uint16 length - info.extend_from_slice(&(output_len as u16).to_be_bytes()); - - // opaque label<6..255> = "dtls13" + Label - info.push(full_label_len as u8); - info.extend_from_slice(b"dtls13"); - info.extend_from_slice(label); - - // opaque context<0..255> - info.push(context.len() as u8); - info.extend_from_slice(context); - - // Now do regular HKDF-Expand - self.hkdf_expand(hash, secret, &info, out, output_len) - } -} - -/// Static instance of the HKDF provider. -pub(super) static HKDF_PROVIDER: RustCryptoHkdfProvider = RustCryptoHkdfProvider; - -#[cfg(test)] -mod tests { - use super::*; - - // RFC 5869 Test Case 1 - Basic test case with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case1() { - let provider = RustCryptoHkdfProvider; - - // IKM = 0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b (22 bytes) - let ikm = [0x0b; 22]; - - // salt = 0x000102030405060708090a0b0c (13 bytes) - let salt = [ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, - ]; - - // info = 0xf0f1f2f3f4f5f6f7f8f9 (10 bytes) - let info = [0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9]; - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, - 0xba, 0x63, 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, 0x22, 0xec, 0x84, 0x4a, - 0xd7, 0xc2, 0xb3, 0xe5, - ]; - - // Expected OKM (42 bytes) - let expected_okm = [ - 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, - 0x2f, 0x2a, 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, 0x5d, 0xb0, 0x2d, 0x56, - 0xec, 0xc4, 0xc5, 0xbf, 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 42) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // RFC 5869 Test Case 2 - Longer inputs/outputs with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case2() { - let provider = RustCryptoHkdfProvider; - - // IKM = 0x000102...4f (80 bytes) - let ikm: Vec = (0x00..=0x4f).collect(); - - // salt = 0x606162...af (80 bytes) - let salt: Vec = (0x60..=0xaf).collect(); - - // info = 0xb0b1b2...ff (80 bytes) - let info: Vec = (0xb0..=0xff).collect(); - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x06, 0xa6, 0xb8, 0x8c, 0x58, 0x53, 0x36, 0x1a, 0x06, 0x10, 0x4c, 0x9c, 0xeb, 0x35, - 0xb4, 0x5c, 0xef, 0x76, 0x00, 0x14, 0x90, 0x46, 0x71, 0x01, 0x4a, 0x19, 0x3f, 0x40, - 0xc1, 0x5f, 0xc2, 0x44, - ]; - - // Expected OKM (82 bytes) - let expected_okm = [ - 0xb1, 0x1e, 0x39, 0x8d, 0xc8, 0x03, 0x27, 0xa1, 0xc8, 0xe7, 0xf7, 0x8c, 0x59, 0x6a, - 0x49, 0x34, 0x4f, 0x01, 0x2e, 0xda, 0x2d, 0x4e, 0xfa, 0xd8, 0xa0, 0x50, 0xcc, 0x4c, - 0x19, 0xaf, 0xa9, 0x7c, 0x59, 0x04, 0x5a, 0x99, 0xca, 0xc7, 0x82, 0x72, 0x71, 0xcb, - 0x41, 0xc6, 0x5e, 0x59, 0x0e, 0x09, 0xda, 0x32, 0x75, 0x60, 0x0c, 0x2f, 0x09, 0xb8, - 0x36, 0x77, 0x93, 0xa9, 0xac, 0xa3, 0xdb, 0x71, 0xcc, 0x30, 0xc5, 0x81, 0x79, 0xec, - 0x3e, 0x87, 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, 0x1d, 0x87, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 82) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // RFC 5869 Test Case 3 - Zero-length salt and info with SHA-256 - #[test] - fn test_hkdf_sha256_rfc5869_case3() { - let provider = RustCryptoHkdfProvider; - - // IKM = 0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b (22 bytes) - let ikm = [0x0b; 22]; - - // salt = empty - let salt: [u8; 0] = []; - - // info = empty - let info: [u8; 0] = []; - - // Expected PRK (32 bytes) - let expected_prk = [ - 0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16, 0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, - 0x8b, 0xdf, 0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77, 0xac, 0x43, 0x4c, 0x1c, - 0x29, 0x3c, 0xcb, 0x04, - ]; - - // Expected OKM (42 bytes) - let expected_okm = [ - 0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, - 0x5a, 0x31, 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, 0xc3, 0x45, 0x4e, 0x5f, - 0x3c, 0x73, 0x8d, 0x2d, 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, 0x96, 0xc8, - ]; - - // Test extract - let mut prk = Buf::new(); - provider - .hkdf_extract(HashAlgorithm::SHA256, &salt, &ikm, &mut prk) - .unwrap(); - assert_eq!(&*prk, &expected_prk[..]); - - // Test expand - let mut okm = Buf::new(); - provider - .hkdf_expand(HashAlgorithm::SHA256, &prk, &info, &mut okm, 42) - .unwrap(); - assert_eq!(&*okm, &expected_okm[..]); - } - - // Test HKDF-Expand-Label structure is built correctly - #[test] - fn test_hkdf_expand_label_basic() { - let provider = RustCryptoHkdfProvider; - let secret = [0u8; 32]; - let mut out = Buf::new(); - - // Should succeed with valid inputs - provider - .hkdf_expand_label(HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16) - .unwrap(); - assert_eq!(out.len(), 16); - - // Should succeed with context - provider - .hkdf_expand_label( - HashAlgorithm::SHA256, - &secret, - b"iv", - &[1, 2, 3, 4], - &mut out, - 12, - ) - .unwrap(); - assert_eq!(out.len(), 12); - } - - // Test DTLS 1.3 expand label - #[test] - fn test_hkdf_expand_label_dtls13_basic() { - let provider = RustCryptoHkdfProvider; - let secret = [0u8; 32]; - let mut out = Buf::new(); - - // Should succeed with valid inputs - provider - .hkdf_expand_label_dtls13(HashAlgorithm::SHA256, &secret, b"key", &[], &mut out, 16) - .unwrap(); - assert_eq!(out.len(), 16); - - // TLS 1.3 and DTLS 1.3 with same inputs should produce different outputs - // due to different label prefixes ("tls13 " vs "dtls13") - let mut tls_out = Buf::new(); - let mut dtls_out = Buf::new(); - - provider - .hkdf_expand_label( - HashAlgorithm::SHA256, - &secret, - b"key", - &[], - &mut tls_out, - 16, - ) - .unwrap(); - provider - .hkdf_expand_label_dtls13( - HashAlgorithm::SHA256, - &secret, - b"key", - &[], - &mut dtls_out, - 16, - ) - .unwrap(); - - assert_ne!(&*tls_out, &*dtls_out); - } -} diff --git a/src/crypto/rust_crypto/hmac.rs b/src/crypto/rust_crypto/hmac.rs index 03082e67..16650efd 100644 --- a/src/crypto/rust_crypto/hmac.rs +++ b/src/crypto/rust_crypto/hmac.rs @@ -1,101 +1,46 @@ -//! HMAC utilities using RustCrypto. +//! HMAC implementation using RustCrypto. use hmac::{Hmac, Mac}; use sha2::{Sha256, Sha384}; use super::super::HmacProvider; -use crate::buffer::Buf; use crate::types::HashAlgorithm; -/// Compute HMAC using TLS 1.2 P_hash algorithm. -pub(super) fn p_hash( - hash_alg: HashAlgorithm, - secret: &[u8], - full_seed: &[u8], - out: &mut Buf, - output_len: usize, -) -> Result<(), String> { - out.clear(); - - // A(1) = HMAC_hash(secret, A(0)) where A(0) = seed - match hash_alg { - HashAlgorithm::SHA256 => { - let mut a_hmac = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - a_hmac.update(full_seed); - let mut a = a_hmac.finalize().into_bytes(); - - while out.len() < output_len { - // HMAC_hash(secret, A(i) + seed) - let mut ctx = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - ctx.update(&a); - ctx.update(full_seed); - let output = ctx.finalize().into_bytes(); - - let remaining = output_len - out.len(); - let to_copy = std::cmp::min(remaining, output.len()); - out.extend_from_slice(&output[..to_copy]); - - if out.len() < output_len { - // A(i+1) = HMAC_hash(secret, A(i)) - let mut next_a = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - next_a.update(&a); - a = next_a.finalize().into_bytes(); - } - } - } - HashAlgorithm::SHA384 => { - let mut a_hmac = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - a_hmac.update(full_seed); - let mut a = a_hmac.finalize().into_bytes(); - - while out.len() < output_len { - // HMAC_hash(secret, A(i) + seed) - let mut ctx = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - ctx.update(&a); - ctx.update(full_seed); - let output = ctx.finalize().into_bytes(); - - let remaining = output_len - out.len(); - let to_copy = std::cmp::min(remaining, output.len()); - out.extend_from_slice(&output[..to_copy]); - - if out.len() < output_len { - // A(i+1) = HMAC_hash(secret, A(i)) - let mut next_a = Hmac::::new_from_slice(secret) - .map_err(|_| "Invalid HMAC key length".to_string())?; - next_a.update(&a); - a = next_a.finalize().into_bytes(); - } - } - } - _ => return Err(format!("Unsupported HMAC hash algorithm: {:?}", hash_alg)), - } - - Ok(()) -} - /// HMAC provider implementation. #[derive(Debug)] -pub(super) struct RustCryptoHmacProvider; +pub(crate) struct RustCryptoHmacProvider; impl HmacProvider for RustCryptoHmacProvider { - fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Result<[u8; 32], String> { - let mut mac = - Hmac::::new_from_slice(key).map_err(|_| "Invalid HMAC key".to_string())?; - mac.update(data); - let result = mac.finalize(); - let bytes = result.into_bytes(); - - let mut output = [0u8; 32]; - output.copy_from_slice(&bytes); - Ok(output) + fn hmac( + &self, + hash: HashAlgorithm, + key: &[u8], + data: &[u8], + out: &mut [u8], + ) -> Result { + match hash { + HashAlgorithm::SHA256 => { + let mut mac = Hmac::::new_from_slice(key) + .map_err(|_| "Invalid HMAC key".to_string())?; + mac.update(data); + let result = mac.finalize().into_bytes(); + let len = result.len(); + out[..len].copy_from_slice(&result); + Ok(len) + } + HashAlgorithm::SHA384 => { + let mut mac = Hmac::::new_from_slice(key) + .map_err(|_| "Invalid HMAC key".to_string())?; + mac.update(data); + let result = mac.finalize().into_bytes(); + let len = result.len(); + out[..len].copy_from_slice(&result); + Ok(len) + } + _ => Err(format!("Unsupported HMAC hash algorithm: {:?}", hash)), + } } } /// Static instance of the HMAC provider. -pub(super) static HMAC_PROVIDER: RustCryptoHmacProvider = RustCryptoHmacProvider; +pub(crate) static HMAC_PROVIDER: RustCryptoHmacProvider = RustCryptoHmacProvider; diff --git a/src/crypto/rust_crypto/mod.rs b/src/crypto/rust_crypto/mod.rs index 8b4bcb5d..ebab8589 100644 --- a/src/crypto/rust_crypto/mod.rs +++ b/src/crypto/rust_crypto/mod.rs @@ -41,12 +41,10 @@ mod cipher_suite; mod hash; -mod hkdf; -mod hmac; +pub(crate) mod hmac; mod kx_group; mod random; mod sign; -mod tls12; use super::CryptoProvider; @@ -84,12 +82,6 @@ use super::CryptoProvider; /// - SEC1 DER format (OpenSSL EC private key format) /// - PEM encoded versions of the above /// -/// # TLS 1.2 PRF -/// -/// Implements the TLS 1.2 PRF for key derivation, including: -/// - Standard PRF for master secret and key expansion -/// - Extended Master Secret (RFC 7627) for improved security -/// /// # Random Number Generation /// /// Uses `OsRng` from the `rand` crate for cryptographically secure random number generation. @@ -104,9 +96,7 @@ pub fn default_provider() -> CryptoProvider { hmac_provider: &hmac::HMAC_PROVIDER, // DTLS 1.2 components cipher_suites: cipher_suite::ALL_CIPHER_SUITES, - prf_provider: &tls12::PRF_PROVIDER, // DTLS 1.3 components dtls13_cipher_suites: cipher_suite::ALL_DTLS13_CIPHER_SUITES, - hkdf_provider: &hkdf::HKDF_PROVIDER, } } diff --git a/src/crypto/rust_crypto/tls12.rs b/src/crypto/rust_crypto/tls12.rs deleted file mode 100644 index a14cb7a5..00000000 --- a/src/crypto/rust_crypto/tls12.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! TLS 1.2 PRF using RustCrypto. - -use super::super::PrfProvider; -use crate::buffer::Buf; -use crate::types::HashAlgorithm; - -use super::hmac; - -/// PRF provider implementation for TLS 1.2. -#[derive(Debug)] -pub(super) struct RustCryptoPrfProvider; - -impl PrfProvider for RustCryptoPrfProvider { - fn prf_tls12( - &self, - secret: &[u8], - label: &str, - seed: &[u8], - out: &mut Buf, - output_len: usize, - scratch: &mut Buf, - hash: HashAlgorithm, - ) -> Result<(), String> { - assert!(label.is_ascii(), "Label must be ASCII"); - - // Compute full_seed = label + seed using scratch buffer - scratch.clear(); - scratch.extend_from_slice(label.as_bytes()); - scratch.extend_from_slice(seed); - - hmac::p_hash(hash, secret, scratch, out, output_len) - } -} - -/// Static instance of the PRF provider. -pub(super) static PRF_PROVIDER: RustCryptoPrfProvider = RustCryptoPrfProvider; diff --git a/src/crypto/validation/mod.rs b/src/crypto/validation/mod.rs index a32eca16..1c2c05ab 100644 --- a/src/crypto/validation/mod.rs +++ b/src/crypto/validation/mod.rs @@ -72,7 +72,7 @@ impl CryptoProvider { self.validate_cipher_suites()?; self.validate_kx_groups()?; let validated_hashes = self.validate_hash_providers()?; - self.validate_prf_provider(&validated_hashes)?; + self.validate_prf(&validated_hashes)?; self.validate_signature_verifier(&validated_hashes)?; self.validate_hmac_provider()?; self.validate_dtls13_cipher_suites()?; @@ -149,43 +149,37 @@ impl CryptoProvider { Ok(required_hashes) } - /// Validate that PRF provider works for every supported hash algorithm. - fn validate_prf_provider(&self, validated_hashes: &[HashAlgorithm]) -> Result<(), Error> { - // Test PRF with known test vector (RFC 5246 test vector) - // PRF(secret, label, seed) should be deterministic + /// Validate that PRF (via HMAC) works for every supported hash algorithm. + fn validate_prf(&self, validated_hashes: &[HashAlgorithm]) -> Result<(), Error> { let secret = b"test_secret"; let label = "test label"; let seed = b"test_seed"; let output_len = 32; - // Test PRF for each validated hash algorithm for &hash_alg in validated_hashes { let mut result = Buf::new(); let mut scratch = Buf::new(); - self.prf_provider - .prf_tls12( - secret, - label, - seed, - &mut result, - output_len, - &mut scratch, - hash_alg, - ) - .map_err(|e| { - Error::ConfigError(format!("PRF provider failed for {:?}: {}", hash_alg, e)) - })?; + super::prf_hkdf::prf_tls12( + self.hmac_provider, + secret, + label, + seed, + &mut result, + output_len, + &mut scratch, + hash_alg, + ) + .map_err(|e| Error::ConfigError(format!("PRF failed for {:?}: {}", hash_alg, e)))?; if result.len() != output_len { return Err(Error::ConfigError(format!( - "PRF provider {:?} returned wrong length: expected {}, got {}", + "PRF {:?} returned wrong length: expected {}, got {}", hash_alg, output_len, result.len() ))); } - // Verify the exact output matches expected test vector let maybe_expected = PRF_TEST_VECTORS .iter() .find(|(h, _)| *h == hash_alg) @@ -200,7 +194,7 @@ impl CryptoProvider { if result.as_ref() != *expected { return Err(Error::ConfigError(format!( - "PRF provider {:?} produced incorrect result", + "PRF {:?} produced incorrect result", hash_alg ))); } @@ -260,25 +254,24 @@ impl CryptoProvider { )); } - // Verify HKDF works for each DTLS 1.3 cipher suite's hash algorithm + // Verify HKDF (via HMAC) works for each DTLS 1.3 cipher suite's hash algorithm for cs in self.dtls13_cipher_suites { let hash = cs.suite().hash_algorithm(); let hash_len = hash.output_len(); let zeros = [0u8; 48]; let zeros = &zeros[..hash_len]; let mut out = Buf::new(); - self.hkdf_provider - .hkdf_extract(hash, zeros, zeros, &mut out) + super::prf_hkdf::hkdf_extract(self.hmac_provider, hash, zeros, zeros, &mut out) .map_err(|e| { Error::ConfigError(format!( - "HKDF provider failed for DTLS 1.3 suite {:?}: {}", + "HKDF failed for DTLS 1.3 suite {:?}: {}", cs.suite(), e )) })?; if out.is_empty() { return Err(Error::ConfigError(format!( - "HKDF provider returned empty output for {:?}", + "HKDF returned empty output for {:?}", cs.suite() ))); } diff --git a/src/dtls12/context.rs b/src/dtls12/context.rs index 58887b84..59a7fcda 100644 --- a/src/dtls12/context.rs +++ b/src/dtls12/context.rs @@ -211,7 +211,8 @@ impl CryptoContext { let Some(pms) = &self.pre_master_secret else { return Err("Pre-master secret not available".to_string()); }; - self.provider().prf_provider.prf_tls12( + crypto::prf_hkdf::prf_tls12( + self.provider().hmac_provider, pms, "extended master secret", session_hash, @@ -276,7 +277,8 @@ impl CryptoContext { seed[32..].copy_from_slice(client_random); // Generate key material using PRF - self.provider().prf_provider.prf_tls12( + crypto::prf_hkdf::prf_tls12( + self.provider().hmac_provider, master_secret, "key expansion", &seed, @@ -418,7 +420,8 @@ impl CryptoContext { }; // Generate 12 bytes of verify data using PRF - self.provider().prf_provider.prf_tls12( + crypto::prf_hkdf::prf_tls12( + self.provider().hmac_provider, master_secret, label, handshake_hash, @@ -468,7 +471,8 @@ impl CryptoContext { seed.try_extend_from_slice(server_random) .expect("server_random too long"); - self.provider().prf_provider.prf_tls12( + crypto::prf_hkdf::prf_tls12( + self.provider().hmac_provider, master_secret, DTLS_SRTP_KEY_LABEL, &seed, diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index f22961c3..fb3074a8 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -9,11 +9,12 @@ use super::queue::{QueueRx, QueueTx}; use crate::buffer::{Buf, BufferPool, TmpBuf}; use crate::crypto::Aad; use crate::crypto::Cipher; -use crate::crypto::HkdfProvider; +use crate::crypto::HmacProvider; use crate::crypto::Nonce; use crate::crypto::SigningKey; use crate::crypto::SupportedDtls13CipherSuite; use crate::crypto::SupportedKxGroup; +use crate::crypto::prf_hkdf; use crate::dtls13::incoming::{Incoming, RecordDecrypt}; use crate::dtls13::message::Body; use crate::dtls13::message::ContentType; @@ -1542,8 +1543,8 @@ impl Engine { // Key Schedule // ========================================================================= - fn hkdf(&self) -> &dyn HkdfProvider { - self.config.crypto_provider().hkdf_provider + fn hmac(&self) -> &dyn HmacProvider { + self.config.crypto_provider().hmac_provider } fn hash_algorithm(&self) -> HashAlgorithm { @@ -1569,8 +1570,7 @@ impl Engine { let zeros = [0u8; 48]; let zeros = &zeros[..hash_len]; let mut early_secret = self.buffers_free.pop(); - self.hkdf() - .hkdf_extract(hash, zeros, zeros, &mut early_secret) + prf_hkdf::hkdf_extract(self.hmac(), hash, zeros, zeros, &mut early_secret) .map_err(|e| Error::CryptoError(format!("Failed to derive early secret: {}", e)))?; Ok(early_secret) } @@ -1582,17 +1582,18 @@ impl Engine { &mut self, shared_secret: &[u8], ) -> Result<(Buf, Buf, Buf), Error> { - // Call derive_early_secret first (needs &mut self) before borrowing hkdf + // Call derive_early_secret first (needs &mut self) before borrowing hmac let early_secret = self.derive_early_secret()?; let hash = self.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); // Derive-Secret(early_secret, "derived", "") let empty_hash = self.transcript_hash_of(b""); let mut derived = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &early_secret, b"derived", @@ -1604,7 +1605,7 @@ impl Engine { // handshake_secret = HKDF-Extract(derived, shared_secret) let mut handshake_secret = Buf::new(); - hkdf.hkdf_extract(hash, &derived, shared_secret, &mut handshake_secret) + prf_hkdf::hkdf_extract(hmac, hash, &derived, shared_secret, &mut handshake_secret) .map_err(|e| Error::CryptoError(format!("Failed to derive handshake secret: {}", e)))?; // Get transcript hash up to and including ServerHello @@ -1613,7 +1614,8 @@ impl Engine { // client_handshake_traffic_secret let mut c_hs_traffic = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &handshake_secret, b"c hs traffic", @@ -1625,7 +1627,8 @@ impl Engine { // server_handshake_traffic_secret let mut s_hs_traffic = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &handshake_secret, b"s hs traffic", @@ -1668,12 +1671,13 @@ impl Engine { ) -> Result<(Buf, Buf), Error> { let hash = self.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); // Derive-Secret(handshake_secret, "derived", "") let empty_hash = self.transcript_hash_of(b""); let mut derived = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, handshake_secret, b"derived", @@ -1687,7 +1691,7 @@ impl Engine { let zeros = [0u8; 48]; let zeros = &zeros[..hash_len]; let mut master_secret = Buf::new(); - hkdf.hkdf_extract(hash, &derived, zeros, &mut master_secret) + prf_hkdf::hkdf_extract(hmac, hash, &derived, zeros, &mut master_secret) .map_err(|e| Error::CryptoError(format!("Failed to derive master secret: {}", e)))?; // Get transcript hash up to and including server Finished @@ -1696,7 +1700,8 @@ impl Engine { // exporter_master_secret = Derive-Secret(master_secret, "exp master", transcript_hash) let mut exp_master = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &master_secret, b"exp master", @@ -1710,7 +1715,8 @@ impl Engine { // client_application_traffic_secret_0 let mut c_ap_traffic = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &master_secret, b"c ap traffic", @@ -1722,7 +1728,8 @@ impl Engine { // server_application_traffic_secret_0 let mut s_ap_traffic = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &master_secret, b"s ap traffic", @@ -1732,7 +1739,7 @@ impl Engine { ) .map_err(|e| Error::CryptoError(format!("Failed to derive s_ap_traffic: {}", e)))?; - // Store exporter master secret (deferred to avoid borrow conflict with hkdf) + // Store exporter master secret (deferred to avoid borrow conflict with hmac) self.exporter_master_secret = Some(exp_master); Ok((c_ap_traffic, s_ap_traffic)) @@ -1772,13 +1779,19 @@ impl Engine { fn derive_next_traffic_secret(&self, current: &Buf) -> Result { let hash = self.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); let mut next = Buf::new(); - hkdf.hkdf_expand_label_dtls13(hash, current, b"traffic upd", &[], &mut next, hash_len) - .map_err(|e| { - Error::CryptoError(format!("Failed to derive next traffic secret: {}", e)) - })?; + prf_hkdf::hkdf_expand_label_dtls13( + hmac, + hash, + current, + b"traffic upd", + &[], + &mut next, + hash_len, + ) + .map_err(|e| Error::CryptoError(format!("Failed to derive next traffic secret: {}", e)))?; Ok(next) } @@ -1898,16 +1911,25 @@ impl Engine { fn derive_epoch_keys(&self, traffic_secret: &Buf) -> Result { let hash = self.hash_algorithm(); let suite = self.suite_provider(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); // key = HKDF-Expand-Label(secret, "key", "", key_length) let mut key = Buf::new(); - hkdf.hkdf_expand_label_dtls13(hash, traffic_secret, b"key", &[], &mut key, suite.key_len()) - .map_err(|e| Error::CryptoError(format!("Failed to derive key: {}", e)))?; + prf_hkdf::hkdf_expand_label_dtls13( + hmac, + hash, + traffic_secret, + b"key", + &[], + &mut key, + suite.key_len(), + ) + .map_err(|e| Error::CryptoError(format!("Failed to derive key: {}", e)))?; // iv = HKDF-Expand-Label(secret, "iv", "", iv_length) let mut iv_buf = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, traffic_secret, b"iv", @@ -1919,7 +1941,8 @@ impl Engine { // sn_key = HKDF-Expand-Label(secret, "sn", "", key_length) let mut sn_key = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, traffic_secret, b"sn", @@ -1959,11 +1982,12 @@ impl Engine { pub fn compute_verify_data(&self, traffic_secret: &[u8]) -> Result { let hash = self.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); // finished_key = HKDF-Expand-Label(secret, "finished", "", Hash.len) let mut finished_key = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, traffic_secret, b"finished", @@ -1978,8 +2002,14 @@ impl Engine { // HMAC(finished_key, transcript_hash) via HKDF-Extract(salt=key, IKM=data) let mut verify_data = Buf::new(); - hkdf.hkdf_extract(hash, &finished_key, &transcript_hash, &mut verify_data) - .map_err(|e| Error::CryptoError(format!("Failed to compute verify data: {}", e)))?; + prf_hkdf::hkdf_extract( + hmac, + hash, + &finished_key, + &transcript_hash, + &mut verify_data, + ) + .map_err(|e| Error::CryptoError(format!("Failed to compute verify data: {}", e)))?; Ok(verify_data) } @@ -2119,7 +2149,7 @@ impl Engine { ) -> Result<(ArrayVec, crate::crypto::SrtpProfile), Error> { let hash = self.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = self.hkdf(); + let hmac = self.hmac(); let exp_master = self.exporter_master_secret.as_ref().ok_or_else(|| { Error::CryptoError("Exporter master secret not yet derived".to_string()) @@ -2131,7 +2161,8 @@ impl Engine { // 1. derived_secret = Derive-Secret(exporter_master_secret, label, "") let empty_hash = self.transcript_hash_of(b""); let mut derived = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, exp_master, b"EXTRACTOR-dtls_srtp", @@ -2144,7 +2175,8 @@ impl Engine { // 2. result = HKDF-Expand-Label(derived_secret, "exporter", Hash(context), length) let context_hash = self.transcript_hash_of(b""); let mut keying_material_buf = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &derived, b"exporter", @@ -2439,13 +2471,14 @@ mod tests { // Verify the handshake_secret can reproduce the same c_hs_traffic let hash = engine.hash_algorithm(); let hash_len = hash.output_len(); - let hkdf = engine.hkdf(); + let hmac = engine.hmac(); let mut transcript_hash = Buf::new(); engine.transcript_hash(&mut transcript_hash); let mut c_hs_manual = Buf::new(); - hkdf.hkdf_expand_label_dtls13( + prf_hkdf::hkdf_expand_label_dtls13( + hmac, hash, &handshake_secret, b"c hs traffic",