Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
394 changes: 366 additions & 28 deletions Cargo-minimal.lock

Large diffs are not rendered by default.

394 changes: 366 additions & 28 deletions Cargo-recent.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ members = [
"payjoin-directory",
"payjoin-test-utils",
"payjoin-ffi",
"payjoin-service",
]
resolver = "2"

[patch.crates-io]
ohttp-relay = { path = "ohttp-relay" }
payjoin = { path = "payjoin" }
payjoin-directory = { path = "payjoin-directory" }
payjoin-service = { path = "payjoin-service" }
payjoin-test-utils = { path = "payjoin-test-utils" }

[profile.crane]
Expand Down
8 changes: 4 additions & 4 deletions ohttp-relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ exclude = ["tests"]
default = ["bootstrap"]
bootstrap = ["connect-bootstrap", "ws-bootstrap"]
connect-bootstrap = []
ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]
ws-bootstrap = ["futures", "rustls", "tokio-tungstenite"]
_test-util = []

[dependencies]
byteorder = "1.5.0"
bytes = "1.10.1"
futures = { version = "0.3.31", optional = true }
hex = { package = "hex-conservative", version = "0.1.1" }
http = "1.3.1"
http-body-util = "0.1.3"
hyper = { version = "1.6.0", features = ["http1", "server"] }
Expand All @@ -31,8 +32,7 @@ hyper-rustls = { version = "0.27.7", default-features = false, features = [
"http1",
"ring",
] }
hyper-tungstenite = { version = "0.18.0", optional = true }
hyper-util = { version = "0.1.16", features = ["client-legacy"] }
hyper-util = { version = "0.1.16", features = ["client-legacy", "service"] }
rustls = { version = "0.23.31", optional = true, default-features = false, features = [
"ring",
] }
Expand All @@ -44,11 +44,11 @@ tokio = { version = "1.47.1", features = [
] }
tokio-tungstenite = { version = "0.27.0", optional = true }
tokio-util = { version = "0.7.16", features = ["net", "codec"] }
tower = "0.5"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }

[dev-dependencies]
hex = { package = "hex-conservative", version = "0.1.1" }
mockito = "1.7.0"
rcgen = "0.12"
reqwest = { version = "0.12.23", default-features = false, features = [
Expand Down
16 changes: 9 additions & 7 deletions ohttp-relay/src/bootstrap/connect.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::fmt::Debug;
use std::net::SocketAddr;

use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::body::Bytes;
use hyper::upgrade::Upgraded;
use hyper::{Method, Request, Response};
use hyper_util::rt::TokioIo;
Expand All @@ -11,15 +12,16 @@ use tracing::{error, instrument};
use crate::error::Error;
use crate::{empty, GatewayUri};

pub(crate) fn is_connect_request(req: &Request<Incoming>) -> bool {
Method::CONNECT == req.method()
}
pub(crate) fn is_connect_request<B>(req: &Request<B>) -> bool { Method::CONNECT == req.method() }

#[instrument]
pub(crate) async fn try_upgrade(
req: Request<Incoming>,
pub(crate) async fn try_upgrade<B>(
req: Request<B>,
gateway_origin: GatewayUri,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
where
B: Send + Debug + 'static,
{
let addr = gateway_origin
.to_socket_addr()
.await
Expand Down
15 changes: 10 additions & 5 deletions ohttp-relay/src/bootstrap/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Debug;

use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::body::Bytes;
use hyper::{Request, Response};
use tracing::instrument;

Expand All @@ -13,18 +15,21 @@ pub mod connect;
pub mod ws;

#[instrument]
pub(crate) async fn handle_ohttp_keys(
mut req: Request<Incoming>,
pub(crate) async fn handle_ohttp_keys<B>(
req: Request<B>,
gateway_origin: GatewayUri,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
where
B: Send + Debug + 'static,
{
#[cfg(feature = "connect-bootstrap")]
if connect::is_connect_request(&req) {
return connect::try_upgrade(req, gateway_origin).await;
}

#[cfg(feature = "ws-bootstrap")]
if ws::is_websocket_request(&req) {
return ws::try_upgrade(&mut req, gateway_origin).await;
return ws::try_upgrade(req, gateway_origin).await;
}

Err(Error::BadRequest("Not a supported proxy upgrade request".to_string()))
Expand Down
100 changes: 79 additions & 21 deletions ohttp-relay/src/bootstrap/ws.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,116 @@
use std::fmt::Debug;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Sink, SinkExt, StreamExt};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::body::{Bytes, Incoming};
use hyper::{Request, Response};
use hyper_tungstenite::HyperWebsocket;
use hyper::body::Bytes;
use hyper::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE};
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{tungstenite, WebSocketStream};
use tracing::{error, instrument};

use crate::empty;
use crate::error::Error;
use crate::gateway_uri::GatewayUri;

pub(crate) fn is_websocket_request(req: &Request<Incoming>) -> bool {
hyper_tungstenite::is_upgrade_request(req)
/// Check if the request is a WebSocket upgrade request.
///
/// This is done manually to support generic body types.
/// When bootstrapping moves to axum, this can be replaced with
/// `axum::extract::ws::WebSocketUpgrade`.
pub(crate) fn is_websocket_request<B>(req: &Request<B>) -> bool {
let dominated_by_upgrade = req
.headers()
.get(CONNECTION)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_ascii_lowercase().contains("upgrade"))
.unwrap_or(false);

let upgrade_to_websocket = req
.headers()
.get(UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

dominated_by_upgrade && upgrade_to_websocket && req.headers().contains_key(SEC_WEBSOCKET_KEY)
}

/// Upgrade the request to a WebSocket connection and proxy to the gateway.
///
/// This performs the WebSocket handshake to support generic body types.
/// When bootstrapping moves to axum, this can be replaced with
/// `axum::extract::ws::WebSocketUpgrade`.
#[instrument]
pub(crate) async fn try_upgrade(
req: &mut Request<Incoming>,
pub(crate) async fn try_upgrade<B>(
req: Request<B>,
gateway_origin: GatewayUri,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
where
B: Send + Debug + 'static,
{
let gateway_addr = gateway_origin
.to_socket_addr()
.await
.map_err(|e| Error::InternalServerError(Box::new(e)))?
.ok_or_else(|| Error::NotFound)?;

let (res, websocket) = hyper_tungstenite::upgrade(req, None)
.map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?;
let key = req
.headers()
.get(SEC_WEBSOCKET_KEY)
.ok_or_else(|| Error::BadRequest("Missing Sec-WebSocket-Key header".to_string()))?
.to_str()
.map_err(|_| Error::BadRequest("Invalid Sec-WebSocket-Key header".to_string()))?
.to_string();

let accept_key = derive_accept_key(key.as_bytes());

tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, gateway_addr).await {
error!("Error in websocket connection: {e}");
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
let ws_stream = WebSocketStream::from_raw_socket(
TokioIo::new(upgraded),
tungstenite::protocol::Role::Server,
None,
)
.await;
if let Err(e) = serve_websocket(ws_stream, gateway_addr).await {
error!("Error in websocket connection: {e}");
}
}
Err(e) => error!("WebSocket upgrade error: {}", e),
}
});
let (parts, body) = res.into_parts();
let boxbody = body.map_err(|never| match never {}).boxed();
Ok(Response::from_parts(parts, boxbody))

let res = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(UPGRADE, "websocket")
.header(CONNECTION, "Upgrade")
.header(SEC_WEBSOCKET_ACCEPT, accept_key)
.body(empty())
.map_err(|e| Error::InternalServerError(Box::new(e)))?;

Ok(res)
}

/// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa.
#[instrument]
async fn serve_websocket(
websocket: HyperWebsocket,
#[instrument(skip(ws_stream))]
async fn serve_websocket<S>(
ws_stream: WebSocketStream<S>,
gateway_addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?;
let mut ws_io = WsIo::new(websocket.await?);
let mut ws_io = WsIo::new(ws_stream);
let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?;
Ok(())
}
Expand Down
Loading
Loading