Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions payjoin-cli/src/app/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
use bitcoincore_rpc::bitcoin::Amount;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::{Buf, Bytes, Incoming};
use hyper::body::{Bytes, Incoming};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
Expand Down Expand Up @@ -88,12 +88,10 @@ impl AppTrait for App {
"Sent fallback transaction hex: {:#}",
payjoin::bitcoin::consensus::encode::serialize_hex(&fallback_tx)
);
let psbt = ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()).map_err(
|e| {
log::debug!("Error processing response: {e:?}");
anyhow!("Failed to process response {e}")
},
)?;
let psbt = ctx.process_response(&response.bytes().await?).map_err(|e| {
log::debug!("Error processing response: {e:?}");
anyhow!("Failed to process response {e}")
})?;

self.process_pj_response(psbt)?;
Ok(())
Expand Down Expand Up @@ -279,8 +277,8 @@ impl App {
let (parts, body) = req.into_parts();
let headers = Headers(&parts.headers);
let query_string = parts.uri.query().unwrap_or("");
let body = body.collect().await.map_err(|e| Implementation(e.into()))?.aggregate().reader();
let proposal = UncheckedProposal::from_request(body, query_string, headers)?;
let body = body.collect().await.map_err(|e| Implementation(e.into()))?.to_bytes();
let proposal = UncheckedProposal::from_request(&body, query_string, headers)?;

let payjoin_proposal = self.process_v1_proposal(proposal)?;
let psbt = payjoin_proposal.psbt();
Expand Down
2 changes: 1 addition & 1 deletion payjoin-cli/src/app/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ impl App {
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
match v1_ctx.process_response(&response.bytes().await?) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{re}");
Expand Down
6 changes: 2 additions & 4 deletions payjoin-ffi/src/send/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::io::Cursor;
use std::str::FromStr;
use std::sync::{Arc, Mutex};

Expand Down Expand Up @@ -185,10 +184,9 @@ impl From<payjoin::send::v1::V1Context> for V1Context {
impl V1Context {
///Decodes and validates the response.
/// Call this method with response from receiver to continue BIP78 flow. If the response is valid you will get appropriate PSBT that you should sign and broadcast.
pub fn process_response(&self, response: Vec<u8>) -> Result<String, ResponseError> {
let mut decoder = Cursor::new(response);
pub fn process_response(&self, response: &[u8]) -> Result<String, ResponseError> {
<payjoin::send::v1::V1Context as Clone>::clone(&self.0.clone())
.process_response(&mut decoder)
.process_response(response)
.map(|e| e.to_string())
.map_err(Into::into)
}
Expand Down
2 changes: 1 addition & 1 deletion payjoin-ffi/src/send/uni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl From<super::V1Context> for V1Context {
impl V1Context {
/// Decodes and validates the response.
/// Call this method with response from receiver to continue BIP78 flow. If the response is valid you will get appropriate PSBT that you should sign and broadcast.
pub fn process_response(&self, response: Vec<u8>) -> Result<String, ResponseError> {
pub fn process_response(&self, response: &[u8]) -> Result<String, ResponseError> {
self.0.process_response(response)
}
}
Expand Down
2 changes: 1 addition & 1 deletion payjoin/src/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl From<InternalPayloadError> for PayloadError {
#[derive(Debug)]
pub(crate) enum InternalPayloadError {
/// The payload is not valid utf-8
Utf8(std::string::FromUtf8Error),
Utf8(std::str::Utf8Error),
/// The payload is not a valid PSBT
ParsePsbt(bitcoin::psbt::PsbtParseError),
/// Invalid sender parameters
Expand Down
4 changes: 2 additions & 2 deletions payjoin/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> {

/// Validate the payload of a Payjoin request for PSBT and Params sanity
pub(crate) fn parse_payload(
base64: String,
base64: &str,
query: &str,
supported_versions: &'static [Version],
) -> Result<(Psbt, Params), PayloadError> {
let unchecked_psbt = Psbt::from_str(&base64).map_err(InternalPayloadError::ParsePsbt)?;
let unchecked_psbt = Psbt::from_str(base64).map_err(InternalPayloadError::ParsePsbt)?;

let psbt = unchecked_psbt.validate().map_err(InternalPayloadError::InconsistentPsbt)?;
log::debug!("Received original psbt: {psbt:?}");
Expand Down
7 changes: 1 addition & 6 deletions payjoin/src/receive/v1/exclusive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ pub struct RequestError(InternalRequestError);

#[derive(Debug)]
pub(crate) enum InternalRequestError {
/// I/O error while reading the request body
Io(std::io::Error),
/// A required HTTP header is missing from the request
MissingHeader(&'static str),
/// The Content-Type header has an invalid value
Expand All @@ -43,8 +41,7 @@ impl From<RequestError> for JsonReply {
use InternalRequestError::*;

match &e.0 {
Io(_)
| MissingHeader(_)
MissingHeader(_)
| InvalidContentType(_)
| InvalidContentLength(_)
| ContentLengthTooLarge(_) =>
Expand All @@ -56,7 +53,6 @@ impl From<RequestError> for JsonReply {
impl fmt::Display for RequestError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.0 {
InternalRequestError::Io(e) => write!(f, "{e}"),
InternalRequestError::MissingHeader(header) => write!(f, "Missing header: {header}"),
InternalRequestError::InvalidContentType(content_type) =>
write!(f, "Invalid content type: {content_type}"),
Expand All @@ -70,7 +66,6 @@ impl fmt::Display for RequestError {
impl error::Error for RequestError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self.0 {
InternalRequestError::Io(e) => Some(e),
InternalRequestError::InvalidContentLength(e) => Some(e),
InternalRequestError::MissingHeader(_) => None,
InternalRequestError::InvalidContentType(_) => None,
Expand Down
25 changes: 10 additions & 15 deletions payjoin/src/receive/v1/exclusive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ pub fn build_v1_pj_uri<'a>(

impl UncheckedProposal {
pub fn from_request(
body: impl std::io::Read,
body: &[u8],
Comment on lines 24 to +26
Copy link
Collaborator

@benalleng benalleng Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be helpful! I was having a really tough time figuring out how to get a Read across the ffi layer in the v2_to_v1 python integration test

query: &str,
headers: impl Headers,
) -> Result<Self, ReplyableError> {
let parsed_body = parse_body(headers, body).map_err(ReplyableError::V1)?;
let validated_body = validate_body(headers, body).map_err(ReplyableError::V1)?;

let base64 = String::from_utf8(parsed_body).map_err(InternalPayloadError::Utf8)?;
let base64 = std::str::from_utf8(validated_body).map_err(InternalPayloadError::Utf8)?;

let (psbt, params) = crate::receive::parse_payload(base64, query, SUPPORTED_VERSIONS)
.map_err(ReplyableError::Payload)?;
Expand All @@ -41,10 +41,7 @@ impl UncheckedProposal {
/// Validate the request headers for a Payjoin request
///
/// [`RequestError`] should only be produced here.
fn parse_body(
headers: impl Headers,
mut body: impl std::io::Read,
) -> Result<Vec<u8>, RequestError> {
fn validate_body(headers: impl Headers, body: &[u8]) -> Result<&[u8], RequestError> {
let content_type = headers
.get_header("content-type")
.ok_or(InternalRequestError::MissingHeader("Content-Type"))?;
Expand All @@ -61,9 +58,7 @@ fn parse_body(
return Err(InternalRequestError::ContentLengthTooLarge(content_length).into());
}

let mut buf = vec![0; content_length];
body.read_exact(&mut buf).map_err(InternalRequestError::Io)?;
Ok(buf)
Ok(&body[..content_length])
}

#[cfg(test)]
Expand Down Expand Up @@ -99,9 +94,9 @@ mod tests {
padded_body.resize(MAX_CONTENT_LENGTH + 1, 0);
let headers = MockHeaders::new(padded_body.len() as u64);

let parsed_request = parse_body(headers.clone(), padded_body.as_slice());
assert!(parsed_request.is_err());
match parsed_request {
let validated_request = validate_body(headers.clone(), padded_body.as_slice());
assert!(validated_request.is_err());
match validated_request {
Ok(_) => panic!("Expected error, got success"),
Err(error) => {
assert_eq!(
Expand All @@ -119,8 +114,8 @@ mod tests {
fn test_from_request() -> Result<(), Box<dyn std::error::Error>> {
let body = ORIGINAL_PSBT.as_bytes();
let headers = MockHeaders::new(body.len() as u64);
let parsed_request = parse_body(headers.clone(), body);
assert!(parsed_request.is_ok());
let validated_request = validate_body(headers.clone(), body);
assert!(validated_request.is_ok());

let proposal = UncheckedProposal::from_request(body, QUERY_PARAMS, headers)?;

Expand Down
12 changes: 6 additions & 6 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl Receiver<WithContext> {
Some(body) => body,
None => return Ok(None),
};
match String::from_utf8(body.clone()) {
match std::str::from_utf8(&body) {
// V1 response bodies are utf8 plaintext
Ok(response) => Ok(Some(Receiver { state: self.extract_proposal_from_v1(response)? })),
// V2 response bodies are encrypted binary
Expand All @@ -208,28 +208,28 @@ impl Receiver<WithContext> {

fn extract_proposal_from_v1(
&mut self,
response: String,
response: &str,
) -> Result<UncheckedProposal, ReplyableError> {
self.unchecked_from_payload(response)
}

fn extract_proposal_from_v2(&mut self, response: Vec<u8>) -> Result<UncheckedProposal, Error> {
let (payload_bytes, e) = decrypt_message_a(&response, self.context.s.secret_key().clone())?;
self.context.e = Some(e);
let payload = String::from_utf8(payload_bytes)
let payload = std::str::from_utf8(&payload_bytes)
.map_err(|e| Error::ReplyToSender(InternalPayloadError::Utf8(e).into()))?;
self.unchecked_from_payload(payload).map_err(Error::ReplyToSender)
}

fn unchecked_from_payload(
&mut self,
payload: String,
payload: &str,
) -> Result<UncheckedProposal, ReplyableError> {
let (base64, padded_query) = payload.split_once('\n').unwrap_or_default();
let query = padded_query.trim_matches('\0');
log::trace!("Received query: {query}, base64: {base64}"); // my guess is no \n so default is wrong
let (psbt, mut params) = parse_payload(base64.to_string(), query, SUPPORTED_VERSIONS)
.map_err(ReplyableError::Payload)?;
let (psbt, mut params) =
parse_payload(base64, query, SUPPORTED_VERSIONS).map_err(ReplyableError::Payload)?;

// Output substitution must be disabled for V1 sessions in V2 contexts.
//
Expand Down
3 changes: 0 additions & 3 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ pub struct ValidationError(InternalValidationError);
#[derive(Debug)]
pub(crate) enum InternalValidationError {
Parse,
Io(std::io::Error),
ContentTooLarge,
Proposal(InternalProposalError),
#[cfg(feature = "v2")]
Expand All @@ -120,7 +119,6 @@ impl fmt::Display for ValidationError {

match &self.0 {
Parse => write!(f, "couldn't decode as PSBT or JSON",),
Io(e) => write!(f, "couldn't read PSBT: {e}"),
ContentTooLarge => write!(f, "content is larger than {MAX_CONTENT_LENGTH} bytes"),
Proposal(e) => write!(f, "proposal PSBT error: {e}"),
#[cfg(feature = "v2")]
Expand All @@ -135,7 +133,6 @@ impl std::error::Error for ValidationError {

match &self.0 {
Parse => None,
Io(error) => Some(error),
ContentTooLarge => None,
Proposal(e) => Some(e),
#[cfg(feature = "v2")]
Expand Down
31 changes: 9 additions & 22 deletions payjoin/src/send/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
//! [`bitmask-core`](https://github.com/diba-io/bitmask-core) BDK integration. Bring your own
//! wallet and http client.

use std::io::{BufRead, BufReader};

use bitcoin::psbt::Psbt;
use bitcoin::{FeeRate, ScriptBuf, Weight};
use error::{BuildSenderError, InternalBuildSenderError};
Expand Down Expand Up @@ -275,18 +273,12 @@ impl V1Context {
/// Call this method with response from receiver to continue BIP78 flow. If the response is
/// valid you will get appropriate PSBT that you should sign and broadcast.
#[inline]
pub fn process_response(
self,
response: &mut impl std::io::Read,
) -> Result<Psbt, ResponseError> {
let mut buf_reader = BufReader::with_capacity(MAX_CONTENT_LENGTH + 1, response);
let buffer = buf_reader.fill_buf().map_err(InternalValidationError::Io)?;

if buffer.len() > MAX_CONTENT_LENGTH {
pub fn process_response(self, response: &[u8]) -> Result<Psbt, ResponseError> {
if response.len() > MAX_CONTENT_LENGTH {
return Err(ResponseError::from(InternalValidationError::ContentTooLarge));
}

let res_str = std::str::from_utf8(buffer).map_err(|_| InternalValidationError::Parse)?;
let res_str = std::str::from_utf8(response).map_err(|_| InternalValidationError::Parse)?;
let proposal = Psbt::from_str(res_str).map_err(|_| ResponseError::parse(res_str))?;
self.psbt_context.process_proposal(proposal).map_err(Into::into)
}
Expand Down Expand Up @@ -334,7 +326,7 @@ mod test {
"message": "This version of payjoin is not supported."
})
.to_string();
match ctx.process_response(&mut known_json_error.as_bytes()) {
match ctx.process_response(known_json_error.as_bytes()) {
Err(ResponseError::WellKnown(WellKnownError {
code: ErrorCode::VersionUnsupported,
..
Expand All @@ -348,27 +340,23 @@ mod test {
"message": "This version of payjoin is not supported."
})
.to_string();
match ctx.process_response(&mut invalid_json_error.as_bytes()) {
match ctx.process_response(invalid_json_error.as_bytes()) {
Err(ResponseError::Validation(_)) => (),
_ => panic!("Expected unrecognized JSON error"),
}
}

#[test]
fn process_response_valid() {
let mut cursor = std::io::Cursor::new(PAYJOIN_PROPOSAL.as_bytes());

let ctx = create_v1_context();
let response = ctx.process_response(&mut cursor);
let response = ctx.process_response(PAYJOIN_PROPOSAL.as_bytes());
assert!(response.is_ok())
}

#[test]
fn process_response_invalid_psbt() {
let mut cursor = std::io::Cursor::new(INVALID_PSBT.as_bytes());

let ctx = create_v1_context();
let response = ctx.process_response(&mut cursor);
let response = ctx.process_response(INVALID_PSBT.as_bytes());
match response {
Ok(_) => panic!("Invalid PSBT should have caused an error"),
Err(error) => match error {
Expand All @@ -386,11 +374,10 @@ mod test {
#[test]
fn process_response_invalid_utf8() {
// In UTF-8, 0xF0 represents the start of a 4-byte sequence, so 0xF0 by itself is invalid
let invalid_utf8 = [0xF0];
let mut cursor = std::io::Cursor::new(invalid_utf8);
let invalid_utf8 = &[0xF0];

let ctx = create_v1_context();
let response = ctx.process_response(&mut cursor);
let response = ctx.process_response(invalid_utf8);
match response {
Ok(_) => panic!("Invalid UTF-8 should have caused an error"),
Err(error) => match error {
Expand Down
Loading
Loading