diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index 69047c934..c91072753 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -2546,6 +2546,8 @@ dependencies = [ name = "ohttp-relay" version = "0.0.11" dependencies = [ + "bhttp", + "bitcoin-ohttp", "byteorder", "bytes", "futures", @@ -2816,7 +2818,6 @@ name = "payjoin-directory" version = "0.0.3" dependencies = [ "anyhow", - "bhttp", "bitcoin 0.32.8", "bitcoin-ohttp", "clap 4.5.46", @@ -2878,6 +2879,7 @@ dependencies = [ "anyhow", "axum", "axum-server", + "bitcoin-ohttp", "clap 4.5.46", "config", "flate2", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 69047c934..c91072753 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -2546,6 +2546,8 @@ dependencies = [ name = "ohttp-relay" version = "0.0.11" dependencies = [ + "bhttp", + "bitcoin-ohttp", "byteorder", "bytes", "futures", @@ -2816,7 +2818,6 @@ name = "payjoin-directory" version = "0.0.3" dependencies = [ "anyhow", - "bhttp", "bitcoin 0.32.8", "bitcoin-ohttp", "clap 4.5.46", @@ -2878,6 +2879,7 @@ dependencies = [ "anyhow", "axum", "axum-server", + "bitcoin-ohttp", "clap 4.5.46", "config", "flate2", diff --git a/ohttp-relay/Cargo.toml b/ohttp-relay/Cargo.toml index 98f982745..22e389587 100644 --- a/ohttp-relay/Cargo.toml +++ b/ohttp-relay/Cargo.toml @@ -20,6 +20,7 @@ ws-bootstrap = ["futures", "rustls", "tokio-tungstenite"] _test-util = [] [dependencies] +bhttp = { version = "0.6.1", features = ["http"] } byteorder = "1.5.0" bytes = "1.10.1" futures = { version = "0.3.31", optional = true } @@ -33,6 +34,7 @@ hyper-rustls = { version = "0.27.7", default-features = false, features = [ "ring", ] } hyper-util = { version = "0.1.16", features = ["client-legacy", "service"] } +ohttp = { package = "bitcoin-ohttp", version = "0.6" } rustls = { version = "0.23.31", optional = true, default-features = false, features = [ "ring", ] } diff --git a/ohttp-relay/src/gateway_helpers.rs b/ohttp-relay/src/gateway_helpers.rs new file mode 100644 index 000000000..9e16b7aca --- /dev/null +++ b/ohttp-relay/src/gateway_helpers.rs @@ -0,0 +1,115 @@ +use std::io::Cursor; + +pub const CHACHA20_POLY1305_NONCE_LEN: usize = 32; +pub const POLY1305_TAG_SIZE: usize = 16; +pub const OHTTP_OVERHEAD: usize = CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE; +pub const ENCAPSULATED_MESSAGE_BYTES: usize = 8192; +pub const BHTTP_REQ_BYTES: usize = ENCAPSULATED_MESSAGE_BYTES - OHTTP_OVERHEAD; + +#[derive(Debug)] +pub enum GatewayError { + BadRequest(String), + OhttpKeyRejection(String), + InternalServerError(String), +} + +impl std::fmt::Display for GatewayError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GatewayError::BadRequest(msg) => write!(f, "Bad request: {}", msg), + GatewayError::OhttpKeyRejection(msg) => write!(f, "OHTTP key rejection: {}", msg), + GatewayError::InternalServerError(msg) => write!(f, "Internal server error: {}", msg), + } + } +} + +impl std::error::Error for GatewayError {} + +pub struct DecapsulatedRequest { + pub method: String, + pub uri: String, + pub headers: Vec<(String, String)>, + pub body: Vec, +} + +pub fn decapsulate_ohttp_request( + ohttp_body: &[u8], + ohttp_server: &ohttp::Server, +) -> Result<(DecapsulatedRequest, ohttp::ServerResponse), GatewayError> { + let (bhttp_req, res_ctx) = ohttp_server.decapsulate(ohttp_body).map_err(|e| { + GatewayError::OhttpKeyRejection(format!("OHTTP decapsulation failed: {}", e)) + })?; + + let mut cursor = Cursor::new(bhttp_req); + let bhttp_msg = bhttp::Message::read_bhttp(&mut cursor) + .map_err(|e| GatewayError::BadRequest(format!("Invalid BHTTP: {}", e)))?; + + let method = String::from_utf8(bhttp_msg.control().method().unwrap_or_default().to_vec()) + .unwrap_or_else(|_| "GET".to_string()); + + let uri = format!( + "{}://{}{}", + std::str::from_utf8(bhttp_msg.control().scheme().unwrap_or_default()).unwrap_or("https"), + std::str::from_utf8(bhttp_msg.control().authority().unwrap_or_default()) + .unwrap_or("localhost"), + std::str::from_utf8(bhttp_msg.control().path().unwrap_or_default()).unwrap_or("/") + ); + + let mut headers = Vec::new(); + for field in bhttp_msg.header().fields() { + let name = String::from_utf8_lossy(field.name()).to_string(); + let value = String::from_utf8_lossy(field.value()).to_string(); + headers.push((name, value)); + } + + let body = bhttp_msg.content().to_vec(); + + Ok((DecapsulatedRequest { method, uri, headers, body }, res_ctx)) +} + +pub fn encapsulate_ohttp_response( + status_code: u16, + headers: Vec<(String, String)>, + body: Vec, + res_ctx: ohttp::ServerResponse, +) -> Result, GatewayError> { + let bhttp_status = bhttp::StatusCode::try_from(status_code) + .map_err(|e| GatewayError::InternalServerError(format!("Invalid status code: {}", e)))?; + + let mut bhttp_res = bhttp::Message::response(bhttp_status); + + for (name, value) in &headers { + bhttp_res.put_header(name.as_str(), value.as_str()); + } + + bhttp_res.write_content(&body); + + let mut bhttp_bytes = Vec::new(); + bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).map_err(|e| { + GatewayError::InternalServerError(format!("BHTTP serialization failed: {}", e)) + })?; + + if bhttp_bytes.len() > BHTTP_REQ_BYTES { + return Err(GatewayError::InternalServerError(format!( + "BHTTP response too large: {} > {}", + bhttp_bytes.len(), + BHTTP_REQ_BYTES + ))); + } + + bhttp_bytes.resize(BHTTP_REQ_BYTES, 0); + + let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).map_err(|e| { + GatewayError::InternalServerError(format!("OHTTP encapsulation failed: {}", e)) + })?; + + if ohttp_res.len() != ENCAPSULATED_MESSAGE_BYTES { + return Err(GatewayError::InternalServerError(format!( + "Unexpected OHTTP response size: {} != {}", + ohttp_res.len(), + ENCAPSULATED_MESSAGE_BYTES + ))); + } + + Ok(ohttp_res) +} diff --git a/ohttp-relay/src/lib.rs b/ohttp-relay/src/lib.rs index fb70f01a0..e0c366470 100644 --- a/ohttp-relay/src/lib.rs +++ b/ohttp-relay/src/lib.rs @@ -41,6 +41,9 @@ pub mod gateway_prober; mod gateway_uri; pub mod sentinel; pub use sentinel::SentinelTag; +pub mod gateway_helpers; + +pub use gateway_helpers::{decapsulate_ohttp_request, encapsulate_ohttp_response}; use crate::error::{BoxError, Error}; diff --git a/payjoin-directory/Cargo.toml b/payjoin-directory/Cargo.toml index 0a63d54fa..0e7f29c73 100644 --- a/payjoin-directory/Cargo.toml +++ b/payjoin-directory/Cargo.toml @@ -20,7 +20,6 @@ acme = ["tokio-rustls-acme"] [dependencies] anyhow = "1.0.99" -bhttp = { version = "0.6.1", features = ["http"] } bitcoin = { version = "0.32.7", features = ["base64", "rand-std"] } clap = { version = "4.5.45", features = ["derive", "env"] } config = "0.15.14" diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index b7a14309c..0adb083f5 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -10,10 +10,10 @@ use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes}; use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use hyper::server::conn::http1; -use hyper::{Method, Request, Response, StatusCode, Uri}; +use hyper::{Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use hyper_util::service::TowerToHyperService; -use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES}; +use payjoin::directory::{ShortId, ShortIdError}; use tokio::net::TcpListener; #[cfg(feature = "acme")] use tokio_rustls_acme::AcmeConfig; @@ -28,10 +28,6 @@ use ohttp_relay::SentinelTag; pub use crate::key_config::*; -const CHACHA20_POLY1305_NONCE_LEN: usize = 32; // chacha20poly1305 n_k -const POLY1305_TAG_SIZE: usize = 16; -pub const BHTTP_REQ_BYTES: usize = - ENCAPSULATED_MESSAGE_BYTES - (CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE); const V1_MAX_BUFFER_SIZE: usize = 65536; const V1_REJECT_RES_JSON: &str = @@ -128,7 +124,7 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet { db: D, - ohttp: ohttp::Server, + pub ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option, } @@ -271,15 +267,14 @@ impl Service { } let mut response = match (parts.method, path_segments.as_slice()) { - (Method::POST, ["", ".well-known", "ohttp-gateway"]) => - self.handle_ohttp_gateway(body).await, (Method::GET, ["", ".well-known", "ohttp-gateway"]) => self.handle_ohttp_gateway_get(&query).await, - (Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await, (Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await, - (Method::POST, ["", id]) => self.handle_post_v1(id, query, body).await, (Method::GET, ["", "health"]) => self.health_check().await, (Method::GET, ["", ""]) => handle_directory_home_path().await, + (Method::POST, ["", id]) => self.post_mailbox_or_v1(id, query, body).await, + (Method::GET, ["", id]) => self.get_mailbox(id).await, + (Method::PUT, ["", id]) if self.v1.is_some() => self.put_payjoin_v1(id, body).await, _ => Ok(not_found()), } .unwrap_or_else(|e| e.to_response()); @@ -290,8 +285,7 @@ impl Service { Ok(response) } - /// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject. - async fn handle_post_v1( + async fn post_mailbox_or_v1( &self, id: &str, query: String, @@ -301,120 +295,52 @@ impl Service { B: Body + Send + 'static, B::Error: Into, { - if self.v1.is_some() { - self.post_fallback_v1(id, query, body).await - } else { - Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header(CONTENT_TYPE, "application/json") - .body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?) - } - } - - /// Handle an encapsulated OHTTP request and return an encapsulated response - async fn handle_ohttp_gateway( - &self, - body: B, - ) -> Result>, HandlerError> - where - B: Body + Send + 'static, - B::Error: Into, - { - // Decapsulate OHTTP request - let ohttp_body = body + let body_bytes = body .collect() .await .map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))? .to_bytes(); - let (bhttp_req, res_ctx) = self - .ohttp - .decapsulate(&ohttp_body) - .map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?; - let mut cursor = std::io::Cursor::new(bhttp_req); - let req = bhttp::Message::read_bhttp(&mut cursor) - .map_err(|e| HandlerError::BadRequest(e.into()))?; - let uri = Uri::builder() - .scheme(req.control().scheme().unwrap_or_default()) - .authority(req.control().authority().unwrap_or_default()) - .path_and_query(req.control().path().unwrap_or_default()) - .build()?; - let body = req.content().to_vec(); - let mut http_req = - Request::builder().uri(uri).method(req.control().method().unwrap_or_default()); - for header in req.header().fields() { - http_req = http_req.header(header.name(), header.value()) - } - let request = http_req.body(full(body))?; - - // Handle decapsulated request - let response = self.handle_decapsulated_request(request).await?; - - // Encapsulate OHTTP response - let (parts, body) = response.into_parts(); - let mut bhttp_res = bhttp::Message::response( - bhttp::StatusCode::try_from(parts.status.as_u16()) - .map_err(|e| HandlerError::InternalServerError(e.into()))?, - ); - for (name, value) in parts.headers.iter() { - bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default()); - } - let full_body = body - .collect() - .await - .map_err(|e| HandlerError::InternalServerError(e.into()))? - .to_bytes(); - bhttp_res.write_content(&full_body); - let mut bhttp_bytes = Vec::new(); - bhttp_res - .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes) - .map_err(|e| HandlerError::InternalServerError(e.into()))?; - bhttp_bytes.resize(BHTTP_REQ_BYTES, 0); - let ohttp_res = res_ctx - .encapsulate(&bhttp_bytes) - .map_err(|e| HandlerError::InternalServerError(e.into()))?; - assert!(ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES, "Unexpected OHTTP response size"); - Ok(Response::new(full(ohttp_res))) - } - async fn handle_decapsulated_request( - &self, - req: Request>, - ) -> Result>, HandlerError> { - let path = req.uri().path().to_string(); - let (parts, body) = req.into_parts(); - - let path_segments: Vec<&str> = path.split('/').collect(); - debug!("handle_v2: {:?}", &path_segments); - match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", id]) => self.post_mailbox(id, body).await, - (Method::GET, &["", id]) => self.get_mailbox(id).await, - (Method::PUT, &["", id]) if self.v1.is_some() => self.put_payjoin_v1(id, body).await, - _ => Ok(not_found()), + if body_bytes.len() > V1_MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); } - } - - async fn post_mailbox( - &self, - id: &str, - body: BoxBody, - ) -> Result>, HandlerError> { - let none_response = Response::builder().status(StatusCode::OK).body(empty())?; - trace!("post_mailbox"); let id = ShortId::from_str(id)?; - let req = body - .collect() - .await - .map_err(|e| HandlerError::InternalServerError(e.into()))? - .to_bytes(); - if req.len() > V1_MAX_BUFFER_SIZE { - return Err(HandlerError::PayloadTooLarge); - } + match String::from_utf8(body_bytes.to_vec()) { + Ok(body_str) => { + if self.v1.is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(CONTENT_TYPE, "application/json") + .body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?); + } + trace!("POST mailbox (v1 fallback)"); + self.check_v1_blocklist(&body_str).await?; + let v2_compat_body = format!("{body_str}\n{query}"); + let none_response = Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(full(V1_UNAVAILABLE_RES_JSON))?; + handle_peek( + self.db.post_v1_request_and_wait_for_response(&id, v2_compat_body.into()).await, + none_response, + ) + } + Err(_) => { + if body_bytes.len() < 100 && self.v1.is_some() { + trace!("POST mailbox (invalid v1 body)"); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(full(V1_REJECT_RES_JSON))?); + } - match self.db.post_v2_payload(&id, req.into()).await { - Ok(_) => Ok(none_response), - Err(e) => Err(HandlerError::InternalServerError(e.into())), + trace!("POST mailbox (v2 binary)"); + let none_response = Response::builder().status(StatusCode::OK).body(empty())?; + match self.db.post_v2_payload(&id, body_bytes.into()).await { + Ok(_) => Ok(none_response), + Err(e) => Err(HandlerError::InternalServerError(e.into())), + } + } } } @@ -451,11 +377,15 @@ impl Service { Ok(()) } - async fn put_payjoin_v1( + async fn put_payjoin_v1( &self, id: &str, - body: BoxBody, - ) -> Result>, HandlerError> { + body: B, + ) -> Result>, HandlerError> + where + B: Body + Send + 'static, + B::Error: Into, + { trace!("Put_payjoin_v1"); let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; @@ -463,7 +393,12 @@ impl Service { let req = body .collect() .await - .map_err(|e| HandlerError::InternalServerError(e.into()))? + .map_err(|e| { + HandlerError::InternalServerError(anyhow::anyhow!( + "Failed to read body: {}", + e.into() + )) + })? .to_bytes(); if req.len() > V1_MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); @@ -478,43 +413,6 @@ impl Service { } } - async fn post_fallback_v1( - &self, - id: &str, - query: String, - body: B, - ) -> Result>, HandlerError> - where - B: Body + Send + 'static, - B::Error: Into, - { - trace!("Post fallback v1"); - let none_response = Response::builder() - .status(StatusCode::SERVICE_UNAVAILABLE) - .body(full(V1_UNAVAILABLE_RES_JSON))?; - let bad_request_body_res = - Response::builder().status(StatusCode::BAD_REQUEST).body(full(V1_REJECT_RES_JSON))?; - - let body_bytes = match body.collect().await { - Ok(bytes) => bytes.to_bytes(), - Err(_) => return Ok(bad_request_body_res), - }; - - let body_str = match String::from_utf8(body_bytes.to_vec()) { - Ok(body_str) => body_str, - Err(_) => return Ok(bad_request_body_res), - }; - - self.check_v1_blocklist(&body_str).await?; - - let v2_compat_body = format!("{body_str}\n{query}"); - let id = ShortId::from_str(id)?; - handle_peek( - self.db.post_v1_request_and_wait_for_response(&id, v2_compat_body.into()).await, - none_response, - ) - } - async fn handle_ohttp_gateway_get( &self, query: &str, @@ -589,10 +487,6 @@ impl Service { async fn handle_directory_home_path() -> Result>, HandlerError> { - let mut res = Response::new(empty()); - *res.status_mut() = StatusCode::OK; - res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("text/html")); - let html = r#" @@ -644,7 +538,9 @@ async fn handle_directory_home_path() -> Result "#; - *res.body_mut() = full(html); + let mut res = Response::new(full(html)); + *res.status_mut() = StatusCode::OK; + res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("text/html")); Ok(res) } @@ -654,7 +550,6 @@ enum HandlerError { InternalServerError(anyhow::Error), ServiceUnavailable(anyhow::Error), SenderGone(anyhow::Error), - OhttpKeyRejection(anyhow::Error), BadRequest(anyhow::Error), /// V1 PSBT rejected — returns the BIP78 `original-psbt-rejected` error. V1PsbtRejected(anyhow::Error), @@ -678,15 +573,6 @@ impl HandlerError { error!("Sender gone: {}", e); *res.status_mut() = StatusCode::GONE } - HandlerError::OhttpKeyRejection(e) => { - const OHTTP_KEY_REJECTION_RES_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#; - - warn!("Bad request: Key configuration rejected: {}", e); - *res.status_mut() = StatusCode::BAD_REQUEST; - res.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("application/problem+json")); - *res.body_mut() = full(OHTTP_KEY_REJECTION_RES_JSON); - } HandlerError::BadRequest(e) => { warn!("Bad request: {}", e); *res.status_mut() = StatusCode::BAD_REQUEST diff --git a/payjoin-mailroom/Cargo.toml b/payjoin-mailroom/Cargo.toml index 4b5ba8448..2f7d311a0 100644 --- a/payjoin-mailroom/Cargo.toml +++ b/payjoin-mailroom/Cargo.toml @@ -36,6 +36,7 @@ config = "0.15" flate2 = { version = "1.1", optional = true } ipnet = { version = "2", optional = true } maxminddb = { version = "0.27", optional = true } +ohttp = { package = "bitcoin-ohttp", version = "0.6" } ohttp-relay = { path = "../ohttp-relay", features = ["bootstrap"] } opentelemetry = "0.31" opentelemetry-otlp = { version = "0.31", optional = true, features = [ diff --git a/payjoin-mailroom/src/lib.rs b/payjoin-mailroom/src/lib.rs index 08273bdba..a2f7c4482 100644 --- a/payjoin-mailroom/src/lib.rs +++ b/payjoin-mailroom/src/lib.rs @@ -11,7 +11,7 @@ use ohttp_relay::SentinelTag; use opentelemetry_sdk::metrics::SdkMeterProvider; use rand::Rng; use tokio_listener::{Listener, SystemOptions, UserOptions}; -use tower::{Service, ServiceBuilder}; +use tower::{Service, ServiceBuilder, ServiceExt}; use tracing::info; #[cfg(feature = "access-control")] @@ -20,14 +20,17 @@ pub mod cli; pub mod config; pub mod metrics; pub mod middleware; +pub mod ohttp; use crate::metrics::MetricsService; use crate::middleware::{track_connections, track_metrics}; +use crate::ohttp::OhttpGatewayConfig; #[derive(Clone)] struct Services { directory: payjoin_directory::Service, relay: ohttp_relay::Service, + ohttp_config: OhttpGatewayConfig, metrics: MetricsService, #[cfg(feature = "access-control")] geoip: Option>, @@ -40,10 +43,12 @@ pub async fn serve(config: Config, meter_provider: Option) -> let geoip = init_geoip(&config).await?; let directory = init_directory(&config, sentinel_tag).await?; + let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag); let services = Services { directory, relay: ohttp_relay::Service::new(sentinel_tag).await, + ohttp_config, metrics: MetricsService::new(meter_provider), #[cfg(feature = "access-control")] geoip, @@ -83,10 +88,12 @@ pub async fn serve_manual_tls( let geoip = init_geoip(&config).await?; let directory = init_directory(&config, sentinel_tag).await?; + let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag); let services = Services { directory, relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await, + ohttp_config, metrics: MetricsService::new(None), #[cfg(feature = "access-control")] geoip, @@ -147,10 +154,12 @@ pub async fn serve_acme( let geoip = init_geoip(&config).await?; let directory = init_directory(&config, sentinel_tag).await?; + let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag); let services = Services { directory, relay: ohttp_relay::Service::new(sentinel_tag).await, + ohttp_config, metrics: MetricsService::new(meter_provider), #[cfg(feature = "access-control")] geoip, @@ -362,22 +371,54 @@ fn build_app(services: Services) -> Router { router } -async fn route_request( - State(mut services): State, - req: axum::extract::Request, -) -> Response { +async fn route_request(State(services): State, req: axum::extract::Request) -> Response { if is_relay_request(&req) { - match services.relay.call(req).await { + let mut relay = services.relay.clone(); + match relay.call(req).await { Ok(res) => res.into_response(), Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(), } } else { // The directory service handles all other requests (including 404) - match services.directory.call(req).await { - Ok(res) => res.into_response(), + handle_directory_request(services, req).await + } +} + +async fn handle_directory_request(services: Services, req: axum::extract::Request) -> Response { + let is_ohttp_request = matches!( + (req.method(), req.uri().path()), + (&Method::POST, "/.well-known/ohttp-gateway") | (&Method::POST, "/") + ); + + if is_ohttp_request { + let app = Router::new() + .fallback(directory_handler) + .layer(axum::middleware::from_fn_with_state( + services.ohttp_config.clone(), + crate::ohttp::ohttp_gateway, + )) + .with_state(services.directory.clone()); + + match app.oneshot(req).await { + Ok(response) => response, Err(e) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), } + } else { + directory_handler(State(services.directory), req).await + } +} + +async fn directory_handler( + State(directory): State>, + req: axum::extract::Request, +) -> Response { + let mut dir = directory.clone(); + match dir.call(req).await { + Ok(response) => response.into_response(), + Err(e) => + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("Directory error: {}", e)) + .into_response(), } } @@ -525,9 +566,11 @@ mod tests { ); let sentinel_tag = generate_sentinel_tag(); + let directory = init_directory(&config, sentinel_tag).await.unwrap(); let services = Services { - directory: init_directory(&config, sentinel_tag).await.unwrap(), + directory: directory.clone(), relay: ohttp_relay::Service::new(sentinel_tag).await, + ohttp_config: OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag), metrics: MetricsService::new(Some(provider.clone())), #[cfg(feature = "access-control")] geoip: None, diff --git a/payjoin-mailroom/src/ohttp/middleware.rs b/payjoin-mailroom/src/ohttp/middleware.rs new file mode 100644 index 000000000..c870d155d --- /dev/null +++ b/payjoin-mailroom/src/ohttp/middleware.rs @@ -0,0 +1,155 @@ +use axum::body::Body; +use axum::extract::{Request, State}; +use axum::http::header::CONTENT_TYPE; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use ohttp_relay::gateway_helpers::{decapsulate_ohttp_request, encapsulate_ohttp_response}; +use ohttp_relay::sentinel::{self, SentinelTag}; +use tracing::{error, warn}; + +/// Configuration for the OHTTP gateway middleware +#[derive(Clone)] +pub struct OhttpGatewayConfig { + pub ohttp_server: ohttp::Server, + pub sentinel_tag: SentinelTag, +} + +impl OhttpGatewayConfig { + pub fn new(ohttp_server: ohttp::Server, sentinel_tag: SentinelTag) -> Self { + Self { ohttp_server, sentinel_tag } + } +} + +pub async fn ohttp_gateway( + State(config): State, + req: Request, + next: Next, +) -> Response { + if let Some(header_value) = + req.headers().get(sentinel::HEADER_NAME).and_then(|v| v.to_str().ok()) + { + if sentinel::is_self_loop(&config.sentinel_tag, header_value) { + warn!("Rejected OHTTP request from same-instance relay"); + return ( + StatusCode::FORBIDDEN, + "Relay and gateway must be operated by different entities", + ) + .into_response(); + } + } + + let (parts, body) = req.into_parts(); + let body_bytes = match axum::body::to_bytes(body, usize::MAX).await { + Ok(bytes) => bytes.to_vec(), + Err(_) => { + return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response(); + } + }; + + let (decapsulated_req, res_ctx) = + match decapsulate_ohttp_request(&body_bytes, &config.ohttp_server) { + Ok(result) => result, + Err(e) => { + error!("OHTTP decapsulation failed: {}", e); + return match e { + ohttp_relay::gateway_helpers::GatewayError::OhttpKeyRejection(_) => + ohttp_key_rejection_response(), + ohttp_relay::gateway_helpers::GatewayError::BadRequest(msg) => + (StatusCode::BAD_REQUEST, msg).into_response(), + ohttp_relay::gateway_helpers::GatewayError::InternalServerError(msg) => + (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response(), + }; + } + }; + + let uri = match decapsulated_req.uri.parse::() { + Ok(uri) => uri, + Err(e) => { + error!("Invalid URI in BHTTP: {}", e); + return (StatusCode::BAD_REQUEST, "Invalid URI").into_response(); + } + }; + + let method = + decapsulated_req.method.parse::().unwrap_or(axum::http::Method::GET); + + let mut new_parts = parts; + new_parts.uri = uri; + new_parts.method = method; + + for (name, value) in decapsulated_req.headers { + if let Ok(header_name) = name.parse::() { + if let Ok(header_value) = value.parse::() { + new_parts.headers.insert(header_name, header_value); + } + } + } + + let inner_request = Request::from_parts(new_parts, Body::from(decapsulated_req.body)); + + let response = next.run(inner_request).await; + + let (parts, body) = response.into_parts(); + let response_bytes = match axum::body::to_bytes(body, usize::MAX).await { + Ok(bytes) => bytes.to_vec(), + Err(_) => { + return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response body") + .into_response(); + } + }; + + let headers: Vec<(String, String)> = parts + .headers + .iter() + .map(|(name, value)| { + (name.as_str().to_string(), value.to_str().unwrap_or_default().to_string()) + }) + .collect(); + + let ohttp_response = + match encapsulate_ohttp_response(parts.status.as_u16(), headers, response_bytes, res_ctx) { + Ok(bytes) => bytes, + Err(e) => { + error!("OHTTP encapsulation failed: {}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to encapsulate response") + .into_response(); + } + }; + + (StatusCode::OK, ohttp_response).into_response() +} + +fn ohttp_key_rejection_response() -> Response { + const OHTTP_KEY_REJECTION_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#; + + ( + StatusCode::BAD_REQUEST, + [(CONTENT_TYPE, "application/problem+json")], + OHTTP_KEY_REJECTION_JSON, + ) + .into_response() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_ohttp_server() -> ohttp::Server { + use payjoin_test_utils::{KEM, KEY_ID, SYMMETRIC}; + + let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)) + .expect("Failed to create test OHTTP config"); + + ohttp::Server::new(server_config).expect("Failed to create OHTTP server") + } + + #[test] + fn test_config_creation() { + let server = create_test_ohttp_server(); + let sentinel_tag = SentinelTag::new([0u8; 32]); + + let config = OhttpGatewayConfig::new(server, sentinel_tag); + assert!(std::mem::size_of_val(&config) > 0); + } +} diff --git a/payjoin-mailroom/src/ohttp/mod.rs b/payjoin-mailroom/src/ohttp/mod.rs new file mode 100644 index 000000000..25dc3db56 --- /dev/null +++ b/payjoin-mailroom/src/ohttp/mod.rs @@ -0,0 +1,3 @@ +pub mod middleware; + +pub use middleware::{ohttp_gateway, OhttpGatewayConfig};