From 86eabbaa1abb640ba89f8be9e8d51bb53c84288c Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 10:04:11 +0100 Subject: [PATCH 1/4] Add measurement header injection --- Cargo.lock | 56 ++++++++++--- Cargo.toml | 8 +- src/http_version.rs | 15 ++-- src/lib.rs | 190 +++++++++++++++++++++++++++++++++++++------- 4 files changed, 220 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a169090..45f4bd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -590,6 +590,38 @@ dependencies = [ "x509-parser 0.18.1", ] +[[package]] +name = "attestation" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +dependencies = [ + "anyhow", + "az-tdx-vtpm", + "base64 0.22.1", + "configfs-tsm", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", + "hex", + "http", + "num-bigint", + "once_cell", + "openssl", + "parity-scale-codec", + "pem-rfc7468", + "rand_core 0.6.4", + "reqwest", + "rustls-webpki", + "serde", + "serde_json", + "tdx-quote", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tracing", + "tss-esapi", + "x509-parser 0.18.1", +] + [[package]] name = "attestation-provider-server" version = "0.1.0" @@ -612,7 +644,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "bytes", "futures-util", "http", @@ -638,10 +670,10 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" dependencies = [ "anyhow", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "ra-tls", "rcgen 0.14.7", "rustls", @@ -659,8 +691,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation", - "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "axum", "bytes", "clap", @@ -1072,7 +1104,7 @@ dependencies = [ [[package]] name = "cc-eventlog" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "digest 0.10.7", @@ -1661,7 +1693,7 @@ dependencies = [ [[package]] name = "dstack-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", @@ -1687,7 +1719,7 @@ dependencies = [ [[package]] name = "dstack-types" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "parity-scale-codec", "serde", @@ -2976,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" dependencies = [ "rustls", "tokio", @@ -3673,7 +3705,7 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ra-tls" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "bon", @@ -4480,7 +4512,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "size-parser" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "serde", @@ -4671,7 +4703,7 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tdx-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", diff --git a/Cargo.toml b/Cargo.toml index c285277..17b5749 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } @@ -47,7 +47,7 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier", features = ["mock"] } tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } diff --git a/src/http_version.rs b/src/http_version.rs index 901df66..d2f0af2 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,6 +1,7 @@ //! HTTP Version support and negotiation use hyper::Response; use hyper_util::rt::TokioIo; +use bytes::Bytes; use std::pin::Pin; use std::task::{Context, Poll}; @@ -55,15 +56,18 @@ impl HttpVersion { } } -type Http1Sender = hyper::client::conn::http1::SendRequest; -type Http2Sender = hyper::client::conn::http2::SendRequest; +type Http1Sender = hyper::client::conn::http1::SendRequest>; +type Http2Sender = hyper::client::conn::http2::SendRequest>; type Http1Connection = - hyper::client::conn::http1::Connection, hyper::body::Incoming>; + hyper::client::conn::http1::Connection< + TokioIo, + http_body_util::Full, + >; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, - hyper::body::Incoming, + http_body_util::Full, crate::TokioExecutor, >; @@ -88,8 +92,9 @@ impl From for HttpSender { impl HttpSender { pub async fn send_request( &mut self, - request: http::Request, + request: http::Request, ) -> Result, hyper::Error> { + let request = request.map(http_body_util::Full::new); match self { Self::Http1(sender) => sender.send_request(request).await, Self::Http2(sender) => sender.send_request(request).await, diff --git a/src/lib.rs b/src/lib.rs index 88aa200..8e44005 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod http_version; #[cfg(test)] mod test_helpers; -use attestation::{AttestationError, AttestationVerifier}; +use attestation::{AttestationError, AttestationExchangeMessage, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -36,6 +36,12 @@ use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; +/// The header name for giving attestation type +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -48,7 +54,7 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; type RequestWithResponseSender = ( - http::Request, + http::Request, oneshot::Sender>, hyper::Error>>, ); @@ -399,8 +405,22 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted connection"); + // Get attestation from the remote certificate from the inner session, if present. + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => Some( + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )?, + ), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn handle_inner_connection( @@ -410,8 +430,22 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted inner-only connection"); + // Get attestation from the remote certificate, if present + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => Some( + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )?, + ), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn serve_tls_stream( @@ -419,10 +453,19 @@ impl ProxyServer { http_version: HttpVersion, target: String, client_addr: SocketAddr, + attestation: Option, ) -> Result<(), ProxyError> where IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { + let (remote_attestation_type, measurements) = match attestation { + Some(attestation) => ( + Some(attestation.attestation_type), + attestation.get_measurements()?, + ), + None => (None, None), + }; + // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -447,6 +490,30 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // If we have measurements, from the remote peer, add them to the request header + let measurements = measurements.clone(); + + if let Some(measurements) = measurements { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + if let Some(remote_attestation_type) = remote_attestation_type { + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + } + let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -635,7 +702,7 @@ impl ProxyClient { // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( - http::Request, + http::Request, oneshot::Sender< Result>, hyper::Error>, >, @@ -648,7 +715,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -678,6 +745,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let mut remote_attestation_type = attestation.attestation_type; + let mut measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -689,17 +759,69 @@ impl ProxyClient { if let Some((req, response_tx)) = incoming_req_option { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server - let (response, should_reconnect) = match sender.send_request(req).await { - Ok(resp) => { + let response = loop { + match sender.send_request(req.clone()).await { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - (Ok(resp.map(|b| b.boxed())), false) - } - Err(e) => { + // If we have measurements from the proxy-server, inject them into the + // response header + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + break Ok(resp.map(|b| b.boxed())); + } + Err(e) => { warn!("Failed to send request to proxy-server: {e}"); - let mut resp = Response::new(full(format!("Request failed: {e}"))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - - (Ok(resp), true) + match Self::setup_connection_with_backoff( + &target, + &nesting_tls_connector, + false, + ) + .await + { + Ok((new_sender, new_conn, new_attestation)) => { + sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); + + let (new_conn_done_tx, new_conn_done_rx) = + tokio::sync::watch::channel::>(None); + conn_done_rx = new_conn_done_rx; + + tokio::spawn(async move { + let res = new_conn.await; + let _ = new_conn_done_tx.send(res.err()); + }); + + warn!("Reconnected to proxy-server, retrying request"); + continue; + } + Err(reconnect_err) => { + warn!("Reconnect after request failure failed: {reconnect_err}"); + let mut resp = Response::new(full(format!( + "Request failed: {e}" + ))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + break Ok(resp); + } + } + } } }; @@ -707,12 +829,6 @@ impl ProxyClient { if response_tx.send(response).is_err() { warn!("Failed to forward response to source client, probably they dropped the connection"); } - - if should_reconnect { - // Leave the inner loop and continue on the reconnect loop - warn!("Reconnecting to proxy-server due to failed request"); - break; - } } else { // The request sender was dropped - so no more incoming requests debug!("Request sender dropped - leaving connection handler loop"); @@ -799,7 +915,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -828,15 +944,29 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; + debug!("[proxy-client] Connected to proxy server"); + // Get attestation from session + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -848,20 +978,20 @@ impl ProxyClient { .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) .keep_alive_while_idle(true) - .handshake::<_, hyper::body::Incoming>(outbound_io) + .handshake::<_, http_body_util::Full>(outbound_io) .await?; (sender.into(), conn.into()) } HttpVersion::Http1 => { let (sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, hyper::body::Incoming>(outbound_io) + .handshake::<_, http_body_util::Full>(outbound_io) .await?; (sender.into(), conn.into()) } }; - // Return the HTTP client, as well as remote measurements - Ok((sender, conn)) + // Return the HTTP client, as well as remote attestation + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -869,6 +999,10 @@ impl ProxyClient { req: hyper::Request, requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { + let (parts, body) = req.into_parts(); + let body = body.collect().await?.to_bytes(); + let req = http::Request::from_parts(parts, body); + let (response_tx, response_rx) = oneshot::channel(); requests_tx.send((req, response_tx)).await?; Ok(response_rx.await??) @@ -1230,7 +1364,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) From d45a353b90ffa2e713caa5e0258a97825179d9eb Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 10:37:16 +0100 Subject: [PATCH 2/4] Fix re-connection bug --- src/http_version.rs | 11 ++-- src/lib.rs | 131 +++++++++++++++++++++++++------------------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/src/http_version.rs b/src/http_version.rs index d2f0af2..91948c3 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,7 +1,7 @@ //! HTTP Version support and negotiation +use bytes::Bytes; use hyper::Response; use hyper_util::rt::TokioIo; -use bytes::Bytes; use std::pin::Pin; use std::task::{Context, Poll}; @@ -59,11 +59,10 @@ impl HttpVersion { type Http1Sender = hyper::client::conn::http1::SendRequest>; type Http2Sender = hyper::client::conn::http2::SendRequest>; -type Http1Connection = - hyper::client::conn::http1::Connection< - TokioIo, - http_body_util::Full, - >; +type Http1Connection = hyper::client::conn::http1::Connection< + TokioIo, + http_body_util::Full, +>; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, diff --git a/src/lib.rs b/src/lib.rs index 8e44005..acf25f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -410,11 +410,20 @@ impl ProxyServer { let (_io, server_connection) = tls_stream.get_ref(); match server_connection.peer_certificates() { - Some(remote_cert_chain) => Some( - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )?, - ), + Some(remote_cert_chain) => remote_cert_chain + .first() + .and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) + { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!( + "Failed to extract remote attestation from inner-session certificate: {err}" + ); + None + } + } + }), None => None, } }; @@ -435,11 +444,15 @@ impl ProxyServer { let (_io, server_connection) = tls_stream.get_ref(); match server_connection.peer_certificates() { - Some(remote_cert_chain) => Some( - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )?, - ), + Some(remote_cert_chain) => remote_cert_chain.first().and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!("Failed to extract remote attestation from certificate: {err}"); + None + } + } + }), None => None, } }; @@ -461,7 +474,13 @@ impl ProxyServer { let (remote_attestation_type, measurements) = match attestation { Some(attestation) => ( Some(attestation.attestation_type), - attestation.get_measurements()?, + match attestation.get_measurements() { + Ok(measurements) => measurements, + Err(err) => { + warn!("Failed to extract measurements from peer attestation: {err}"); + None + } + }, ), None => (None, None), }; @@ -715,7 +734,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn, attestation) = + let (mut sender, conn) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -745,9 +764,6 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); - let mut remote_attestation_type = attestation.attestation_type; - let mut measurements = attestation.get_measurements().ok().flatten(); - tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -760,45 +776,60 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let response = loop { - match sender.send_request(req.clone()).await { - Ok(mut resp) => { - debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - // If we have measurements from the proxy-server, inject them into the - // response header - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); + let send_result = tokio::select! { + result = sender.send_request(req.clone()) => result, + _ = conn_done_rx.changed() => { + warn!("Connection dropped while request was in flight"); + match Self::setup_connection_with_backoff( + &target, + &nesting_tls_connector, + true, + ) + .await + { + Ok((new_sender, new_conn)) => { + sender = new_sender; + + let (new_conn_done_tx, new_conn_done_rx) = + tokio::sync::watch::channel::>(None); + conn_done_rx = new_conn_done_rx; + + tokio::spawn(async move { + let res = new_conn.await; + let _ = new_conn_done_tx.send(res.err()); + }); + + warn!("Reconnected to proxy-server, retrying request"); + continue; } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); + Err(reconnect_err) => { + warn!("Reconnect after in-flight drop failed: {reconnect_err}"); + let mut resp = Response::new(full( + "Request failed: connection to proxy-server dropped", + )); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + break Ok(resp); } } } + }; - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); + match send_result { + Ok(resp) => { + debug!("[proxy-client] Read response from proxy-server: {resp:?}"); break Ok(resp.map(|b| b.boxed())); } Err(e) => { - warn!("Failed to send request to proxy-server: {e}"); + warn!("Failed to send request to proxy-server: {e}"); match Self::setup_connection_with_backoff( &target, &nesting_tls_connector, - false, + true, ) .await { - Ok((new_sender, new_conn, new_attestation)) => { + Ok((new_sender, new_conn)) => { sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -915,7 +946,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -944,7 +975,7 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; @@ -954,19 +985,6 @@ impl ProxyClient { debug!("[proxy-client] Connected to proxy server"); - // Get attestation from session - let attestation = { - let (_io, server_connection) = tls_stream.get_ref(); - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)?; - - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )? - }; - // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -990,8 +1008,7 @@ impl ProxyClient { } }; - // Return the HTTP client, as well as remote attestation - Ok((sender, conn, attestation)) + Ok((sender, conn)) } // Handle a request from the source client to the proxy server @@ -1364,7 +1381,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn, _attestation) = ProxyClient::setup_connection( + let (sender, conn) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) From 7ed43a9a9e060f704ca9ab75bc3f4ef435110f26 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 11:30:00 +0100 Subject: [PATCH 3/4] Fully restore measurement header injection --- Cargo.lock | 6 +-- src/lib.rs | 123 +++++++++++++++++++++++++++++++++++++++----- src/test_helpers.rs | 15 +++--- 3 files changed, 121 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 45f4bd6..f04e874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,7 +593,7 @@ dependencies = [ [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "anyhow", "az-tdx-vtpm", @@ -670,7 +670,7 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "anyhow", "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", @@ -3008,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "rustls", "tokio", diff --git a/src/lib.rs b/src/lib.rs index acf25f4..8d60d37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -734,7 +734,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -764,6 +764,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let mut remote_attestation_type = attestation.attestation_type; + let mut measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -787,8 +790,10 @@ impl ProxyClient { ) .await { - Ok((new_sender, new_conn)) => { + Ok((new_sender, new_conn, new_attestation)) => { sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -815,8 +820,26 @@ impl ProxyClient { }; match send_result { - Ok(resp) => { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + break Ok(resp.map(|b| b.boxed())); } Err(e) => { @@ -828,8 +851,10 @@ impl ProxyClient { ) .await { - Ok((new_sender, new_conn)) => { + Ok((new_sender, new_conn, new_attestation)) => { sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -946,7 +971,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -975,7 +1000,7 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; @@ -985,6 +1010,18 @@ impl ProxyClient { debug!("[proxy-client] Connected to proxy server"); + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -1008,7 +1045,7 @@ impl ProxyClient { } }; - Ok((sender, conn)) + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -1207,6 +1244,7 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use std::collections::HashMap; use tokio_rustls::TlsConnector; use super::*; @@ -1215,6 +1253,43 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, }; + fn expected_mock_measurements() -> HashMap { + let zero_measurement = "0".repeat(96); + HashMap::from([ + ("0".to_string(), zero_measurement.clone()), + ("1".to_string(), zero_measurement.clone()), + ("2".to_string(), zero_measurement.clone()), + ("3".to_string(), zero_measurement.clone()), + ("4".to_string(), zero_measurement), + ]) + } + + fn assert_mock_measurements(body: &str) { + let parsed: HashMap = serde_json::from_str(body).unwrap(); + assert_eq!(parsed, expected_mock_measurements()); + } + + fn assert_mock_measurements_header(headers: &http::HeaderMap) { + let body = headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap(); + assert_mock_measurements(body); + } + + fn assert_attestation_type_header(headers: &http::HeaderMap, expected: &str) { + assert_eq!( + headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + Some(expected) + ); + } + + fn assert_no_measurements_header(headers: &http::HeaderMap) { + assert!(headers.get(MEASUREMENT_HEADER).is_none()); + } + #[test] fn proxy_alpn_protocols_prefer_http2() { let mut protocols = Vec::new(); @@ -1381,7 +1456,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) @@ -1445,6 +1520,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1513,8 +1591,11 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert_mock_measurements(&res_body); } // Server has no attestation, client has mock DCAP but no client auth @@ -1574,7 +1655,11 @@ mod tests { .await .unwrap(); - let _res_body = res.text().await.unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); } // Server has mock DCAP, client has mock DCAP and client auth @@ -1641,12 +1726,16 @@ mod tests { let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); } // Server has mock DCAP, client no attestation - just get the server certificate @@ -1874,9 +1963,11 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) + let initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(initial_response.headers(), "dcap-tdx"); + assert_mock_measurements_header(initial_response.headers()); // Now break the connection connection_breaker_tx.send(()).unwrap(); @@ -1886,6 +1977,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1945,6 +2039,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 431c5f8..b8509c0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -12,6 +12,8 @@ use tokio_rustls::rustls::{ }; use tracing_subscriber::{EnvFilter, fmt}; +use crate::MEASUREMENT_HEADER; + static INIT: Once = Once::new(); /// Helper to generate a self-signed certificate for testing with a DNS subject name @@ -127,13 +129,12 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { - // headers - // .get(MEASUREMENT_HEADER) - // .and_then(|v| v.to_str().ok()) - // .unwrap_or("No measurements") - // .to_string() - "No measurements".to_string() +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() } pub fn init_tracing() { From a0478373522ec99bd95a39d6d901dd1c167988f8 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 12:44:34 +0100 Subject: [PATCH 4/4] Use the same reconnect behavior as before --- src/http_version.rs | 16 +-- src/lib.rs | 268 +++++++++++++++++++++++++++----------------- 2 files changed, 170 insertions(+), 114 deletions(-) diff --git a/src/http_version.rs b/src/http_version.rs index 91948c3..901df66 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,5 +1,4 @@ //! HTTP Version support and negotiation -use bytes::Bytes; use hyper::Response; use hyper_util::rt::TokioIo; use std::pin::Pin; @@ -56,17 +55,15 @@ impl HttpVersion { } } -type Http1Sender = hyper::client::conn::http1::SendRequest>; -type Http2Sender = hyper::client::conn::http2::SendRequest>; +type Http1Sender = hyper::client::conn::http1::SendRequest; +type Http2Sender = hyper::client::conn::http2::SendRequest; -type Http1Connection = hyper::client::conn::http1::Connection< - TokioIo, - http_body_util::Full, ->; +type Http1Connection = + hyper::client::conn::http1::Connection, hyper::body::Incoming>; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, - http_body_util::Full, + hyper::body::Incoming, crate::TokioExecutor, >; @@ -91,9 +88,8 @@ impl From for HttpSender { impl HttpSender { pub async fn send_request( &mut self, - request: http::Request, + request: http::Request, ) -> Result, hyper::Error> { - let request = request.map(http_body_util::Full::new); match self { Self::Http1(sender) => sender.send_request(request).await, Self::Http2(sender) => sender.send_request(request).await, diff --git a/src/lib.rs b/src/lib.rs index 8d60d37..e7e012a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,7 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; type RequestWithResponseSender = ( - http::Request, + http::Request, oneshot::Sender>, hyper::Error>>, ); @@ -721,7 +721,7 @@ impl ProxyClient { // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( - http::Request, + http::Request, oneshot::Sender< Result>, hyper::Error>, >, @@ -764,8 +764,8 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); - let mut remote_attestation_type = attestation.attestation_type; - let mut measurements = attestation.get_measurements().ok().flatten(); + let remote_attestation_type = attestation.attestation_type; + let measurements = attestation.get_measurements().ok().flatten(); tokio::spawn(async move { let res = conn.await; @@ -778,106 +778,35 @@ impl ProxyClient { if let Some((req, response_tx)) = incoming_req_option { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server - let response = loop { - let send_result = tokio::select! { - result = sender.send_request(req.clone()) => result, - _ = conn_done_rx.changed() => { - warn!("Connection dropped while request was in flight"); - match Self::setup_connection_with_backoff( - &target, - &nesting_tls_connector, - true, - ) - .await - { - Ok((new_sender, new_conn, new_attestation)) => { - sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); - - let (new_conn_done_tx, new_conn_done_rx) = - tokio::sync::watch::channel::>(None); - conn_done_rx = new_conn_done_rx; - - tokio::spawn(async move { - let res = new_conn.await; - let _ = new_conn_done_tx.send(res.err()); - }); - - warn!("Reconnected to proxy-server, retrying request"); - continue; + let (response, should_reconnect) = match sender.send_request(req).await { + Ok(mut resp) => { + debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); } - Err(reconnect_err) => { - warn!("Reconnect after in-flight drop failed: {reconnect_err}"); - let mut resp = Response::new(full( - "Request failed: connection to proxy-server dropped", - )); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - break Ok(resp); + Err(e) => { + error!("Failed to encode measurement values: {e}"); } } } - }; - - match send_result { - Ok(mut resp) => { - debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - error!("Failed to encode measurement values: {e}"); - } - } - } - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); - break Ok(resp.map(|b| b.boxed())); - } - Err(e) => { - warn!("Failed to send request to proxy-server: {e}"); - match Self::setup_connection_with_backoff( - &target, - &nesting_tls_connector, - true, - ) - .await - { - Ok((new_sender, new_conn, new_attestation)) => { - sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); - - let (new_conn_done_tx, new_conn_done_rx) = - tokio::sync::watch::channel::>(None); - conn_done_rx = new_conn_done_rx; - - tokio::spawn(async move { - let res = new_conn.await; - let _ = new_conn_done_tx.send(res.err()); - }); - - warn!("Reconnected to proxy-server, retrying request"); - continue; - } - Err(reconnect_err) => { - warn!("Reconnect after request failure failed: {reconnect_err}"); - let mut resp = Response::new(full(format!( - "Request failed: {e}" - ))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - break Ok(resp); - } - } - } + (Ok(resp.map(|b| b.boxed())), false) + } + Err(e) => { + warn!("Failed to send request to proxy-server: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + + (Ok(resp), true) } }; @@ -885,6 +814,12 @@ impl ProxyClient { if response_tx.send(response).is_err() { warn!("Failed to forward response to source client, probably they dropped the connection"); } + + if should_reconnect { + // Leave the inner loop and continue on the reconnect loop + warn!("Reconnecting to proxy-server due to failed request"); + break; + } } else { // The request sender was dropped - so no more incoming requests debug!("Request sender dropped - leaving connection handler loop"); @@ -1033,13 +968,13 @@ impl ProxyClient { .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) .keep_alive_while_idle(true) - .handshake::<_, http_body_util::Full>(outbound_io) + .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; (sender.into(), conn.into()) } HttpVersion::Http1 => { let (sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, http_body_util::Full>(outbound_io) + .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; (sender.into(), conn.into()) } @@ -1053,10 +988,6 @@ impl ProxyClient { req: hyper::Request, requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { - let (parts, body) = req.into_parts(); - let body = body.collect().await?.to_bytes(); - let req = http::Request::from_parts(parts, body); - let (response_tx, response_rx) = oneshot::channel(); requests_tx.send((req, response_tx)).await?; Ok(response_rx.await??) @@ -1245,6 +1176,10 @@ where mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; use std::collections::HashMap; + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; use tokio_rustls::TlsConnector; use super::*; @@ -1932,6 +1867,7 @@ mod tests { // This is used to trigger a dropped connection to the proxy server let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); tokio::spawn(async move { let connection_handle = proxy_server.accept().await.unwrap(); @@ -1943,6 +1879,7 @@ mod tests { // Now accept another connection proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); }); let proxy_client = ProxyClient::new_with_tls_config( @@ -1971,6 +1908,7 @@ mod tests { // Now break the connection connection_breaker_tx.send(()).unwrap(); + reconnected_rx.await.unwrap(); // Make another request let res = reqwest::get(format!("http://{}", proxy_client_addr)) @@ -1984,6 +1922,128 @@ mod tests { assert_eq!(res_body, "No measurements"); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_does_not_retry_failed_request() { + init_tracing(); + + let request_count = Arc::new(AtomicUsize::new(0)); + let request_seen = Arc::new(tokio::sync::Notify::new()); + let (release_tx, release_rx) = tokio::sync::watch::channel(false); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get({ + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let release_rx = release_rx.clone(); + + move || { + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let mut release_rx = release_rx.clone(); + + async move { + request_count.fetch_add(1, Ordering::SeqCst); + request_seen.notify_waiters(); + + if !*release_rx.borrow() { + release_rx.changed().await.unwrap(); + } + + "ok" + } + } + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let request_url = format!("http://{}", proxy_client_addr); + let failed_request = tokio::spawn(async move { reqwest::get(request_url).await.unwrap() }); + + loop { + if request_count.load(Ordering::SeqCst) > 0 { + break; + } + + request_seen.notified().await; + } + + connection_breaker_tx.send(()).unwrap(); + release_tx.send(true).unwrap(); + + let failed_response = failed_request.await.unwrap(); + assert_eq!(failed_response.status(), hyper::StatusCode::BAD_GATEWAY); + assert_eq!(request_count.load(Ordering::SeqCst), 1); + + reconnected_rx.await.unwrap(); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + assert_eq!(request_count.load(Ordering::SeqCst), 2); + } + // Use HTTP 1.1 #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() {