Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion payjoin-cli/src/app/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl AppTrait for App {
let (req, ctx) = SenderBuilder::new(psbt, uri.clone())
.build_recommended(fee_rate)
.with_context(|| "Failed to build payjoin request")?
.extract_v1();
.create_v1_post_request();
let http = http_agent()?;
let body = String::from_utf8(req.body.clone()).unwrap();
println!("Sending fallback request to {}", &req.url);
Expand Down
17 changes: 9 additions & 8 deletions payjoin-cli/src/app/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl App {
match self.post_original_proposal(context.clone(), persister).await {
Ok(()) => (),
Err(_) => {
let (req, v1_ctx) = context.extract_v1();
let (req, v1_ctx) = context.create_v1_post_request();
let response = post_request(req).await?;
let psbt = Arc::new(
v1_ctx.process_response(response.bytes().await?.to_vec().as_slice())?,
Expand All @@ -208,8 +208,9 @@ impl App {
sender: Sender<WithReplyKey>,
persister: &SenderPersister,
) -> Result<()> {
let (req, ctx) = sender
.extract_v2(self.unwrap_relay_or_else_fetch(Some(sender.endpoint().clone())).await?)?;
let (req, ctx) = sender.create_v2_post_request(
self.unwrap_relay_or_else_fetch(Some(sender.endpoint().clone())).await?,
)?;
let response = post_request(req).await?;
println!("Posted original proposal...");
let sender = sender.process_response(&response.bytes().await?, ctx).save(persister)?;
Expand All @@ -224,7 +225,7 @@ impl App {
let mut session = sender.clone();
// Long poll until we get a response
loop {
let (req, ctx) = session.extract_req(
let (req, ctx) = session.create_poll_request(
self.unwrap_relay_or_else_fetch(Some(session.endpoint().clone())).await?,
)?;
let response = post_request(req).await?;
Expand Down Expand Up @@ -260,11 +261,11 @@ impl App {

let mut session = session;
loop {
let (req, context) = session.extract_req(&ohttp_relay)?;
let (req, context) = session.create_poll_request(&ohttp_relay)?;
println!("Polling receive request...");
let ohttp_response = post_request(req).await?;
let state_transition = session
.process_res(ohttp_response.bytes().await?.to_vec().as_slice(), context)
.process_response(ohttp_response.bytes().await?.to_vec().as_slice(), context)
.save(persister);
match state_transition {
Ok(OptionalTransitionOutcome::Progress(next_state)) => {
Expand Down Expand Up @@ -439,11 +440,11 @@ impl App {
persister: &ReceiverPersister,
) -> Result<()> {
let (req, ohttp_ctx) = proposal
.extract_req(&self.unwrap_relay_or_else_fetch(None).await?)
.create_post_request(&self.unwrap_relay_or_else_fetch(None).await?)
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
let res = post_request(req).await?;
let payjoin_psbt = proposal.psbt().clone();
proposal.process_res(&res.bytes().await?, ohttp_ctx).save(persister)?;
proposal.process_response(&res.bytes().await?, ohttp_ctx).save(persister)?;
println!(
"Response successful. Watch mempool for successful Payjoin. TXID: {}",
payjoin_psbt.extract_tx_unchecked_fee_rate().compute_txid()
Expand Down
45 changes: 18 additions & 27 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,42 +54,33 @@ impl DbPool {
}

/// Peek using [`DEFAULT_COLUMN`] as the channel type.
pub async fn push_default(&self, subdirectory_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(subdirectory_id, DEFAULT_COLUMN, data).await
pub async fn push_default(&self, mailbox_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(mailbox_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_default(&self, subdirectory_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await
pub async fn peek_default(&self, mailbox_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(mailbox_id, DEFAULT_COLUMN).await
}

pub async fn push_v1(&self, subdirectory_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(subdirectory_id, PJ_V1_COLUMN, data).await
pub async fn push_v1(&self, mailbox_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(mailbox_id, PJ_V1_COLUMN, data).await
}

/// Peek using [`PJ_V1_COLUMN`] as the channel type.
pub async fn peek_v1(&self, subdirectory_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await
pub async fn peek_v1(&self, mailbox_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(mailbox_id, PJ_V1_COLUMN).await
}

async fn push(
&self,
subdirectory_id: &ShortId,
channel_type: &str,
data: Vec<u8>,
) -> Result<()> {
async fn push(&self, mailbox_id: &ShortId, channel_type: &str, data: Vec<u8>) -> Result<()> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(subdirectory_id, channel_type);
let key = channel_name(mailbox_id, channel_type);
() = conn.set(&key, data.clone()).await?;
() = conn.publish(&key, "updated").await?;
Ok(())
}

async fn peek_with_timeout(
&self,
subdirectory_id: &ShortId,
channel_type: &str,
) -> Result<Vec<u8>> {
match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await {
async fn peek_with_timeout(&self, mailbox_id: &ShortId, channel_type: &str) -> Result<Vec<u8>> {
match tokio::time::timeout(self.timeout, self.peek(mailbox_id, channel_type)).await {
Ok(redis_result) => match redis_result {
Ok(result) => Ok(result),
Err(redis_err) => Err(Error::Redis(redis_err)),
Expand All @@ -98,11 +89,11 @@ impl DbPool {
}
}

async fn peek(&self, subdirectory_id: &ShortId, channel_type: &str) -> RedisResult<Vec<u8>> {
async fn peek(&self, mailbox_id: &ShortId, channel_type: &str) -> RedisResult<Vec<u8>> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(subdirectory_id, channel_type);
let key = channel_name(mailbox_id, channel_type);

// Attempt to fetch existing content for the given subdirectory_id and channel_type
// Attempt to fetch existing content for the given mailbox_id and channel_type
if let Ok(data) = conn.get::<_, Vec<u8>>(&key).await {
if !data.is_empty() {
return Ok(data);
Expand All @@ -112,7 +103,7 @@ impl DbPool {

// Set up a temporary listener for changes
let mut pubsub_conn = self.client.get_async_connection().await?.into_pubsub();
let channel_name = channel_name(subdirectory_id, channel_type);
let channel_name = channel_name(mailbox_id, channel_type);
pubsub_conn.subscribe(&channel_name).await?;

// Use a block to limit the scope of the mutable borrow
Expand Down Expand Up @@ -146,6 +137,6 @@ impl DbPool {
}
}

fn channel_name(subdirectory_id: &ShortId, channel_type: &str) -> Vec<u8> {
(subdirectory_id.to_string() + channel_type).into_bytes()
fn channel_name(mailbox_id: &ShortId, channel_type: &str) -> Vec<u8> {
(mailbox_id.to_string() + channel_type).into_bytes()
}
14 changes: 7 additions & 7 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ async fn handle_v2(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("handle_v2: {:?}", &path_segments);
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", id]) => post_subdir(id, body, pool).await,
(Method::GET, &["", id]) => get_subdir(id, pool).await,
(Method::POST, &["", id]) => post_mailbox(id, body, pool).await,
(Method::GET, &["", id]) => get_mailbox(id, pool).await,
(Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await,
_ => Ok(not_found()),
}
Expand Down Expand Up @@ -371,7 +371,7 @@ impl From<hyper::http::Error> for HandlerError {

impl From<ShortIdError> for HandlerError {
fn from(_: ShortIdError) -> Self {
HandlerError::BadRequest(anyhow::anyhow!("subdirectory ID must be 13 bech32 characters"))
HandlerError::BadRequest(anyhow::anyhow!("mailbox ID must be 13 bech32 characters"))
}
}

Expand Down Expand Up @@ -443,13 +443,13 @@ async fn put_payjoin_v1(
}
}

async fn post_subdir(
async fn post_mailbox(
id: &str,
body: BoxBody<Bytes, hyper::Error>,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
let none_response = Response::builder().status(StatusCode::OK).body(empty())?;
trace!("post_subdir");
trace!("post_mailbox");

let id = ShortId::from_str(id)?;

Expand All @@ -465,11 +465,11 @@ async fn post_subdir(
}
}

async fn get_subdir(
async fn get_mailbox(
id: &str,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("get_subdir");
trace!("get_mailbox");
let id = ShortId::from_str(id)?;
let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?;
handle_peek(pool.peek_default(&id).await, timeout_response)
Expand Down
39 changes: 19 additions & 20 deletions payjoin-ffi/python/test/test_payjoin_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def process_receiver_proposal(self, receiver: ReceiveSession, recv_persist
if res is None:
return None
return res

if receiver.is_UNCHECKED_PROPOSAL():
return await self.process_unchecked_proposal(receiver.inner, recv_persister)
if receiver.is_MAYBE_INPUTS_OWNED():
Expand All @@ -78,56 +78,55 @@ async def process_receiver_proposal(self, receiver: ReceiveSession, recv_persist
return await self.process_provisional_proposal(receiver.inner, recv_persister)
if receiver.is_PAYJOIN_PROPOSAL():
return receiver

raise Exception(f"Unknown receiver state: {receiver}")



def create_receiver_context(self, receiver_address: bitcoinffi.Address, directory: Url, ohttp_keys: OhttpKeys, recv_persister: InMemoryReceiverSessionEventLog) -> Initialized:
receiver = UninitializedReceiver().create_session(address=receiver_address, directory=directory.as_string(), ohttp_keys=ohttp_keys, expire_after=None).save(recv_persister)
return receiver

async def retrieve_receiver_proposal(self, receiver: Initialized, recv_persister: InMemoryReceiverSessionEventLog, ohttp_relay: Url):
agent = httpx.AsyncClient()
request: RequestResponse = receiver.extract_req(ohttp_relay.as_string())
request: RequestResponse = receiver.create_poll_request(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
content=request.request.body
)
res = receiver.process_res(response.content, request.client_response).save(recv_persister)
res = receiver.process_response(response.content, request.client_response).save(recv_persister)
if res.is_none():
return None
proposal = res.success()
return await self.process_unchecked_proposal(proposal, recv_persister)

Comment on lines -102 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still quite a few of these stragglers I'll take care of

async def process_unchecked_proposal(self, proposal: UncheckedProposal, recv_persister: InMemoryReceiverSessionEventLog) :
receiver = proposal.check_broadcast_suitability(None, MempoolAcceptanceCallback(self.receiver)).save(recv_persister)
return await self.process_maybe_inputs_owned(receiver, recv_persister)

async def process_maybe_inputs_owned(self, proposal: MaybeInputsOwned, recv_persister: InMemoryReceiverSessionEventLog):
maybe_inputs_owned = proposal.check_inputs_not_owned(IsScriptOwnedCallback(self.receiver)).save(recv_persister)
return await self.process_maybe_inputs_seen(maybe_inputs_owned, recv_persister)

async def process_maybe_inputs_seen(self, proposal: MaybeInputsSeen, recv_persister: InMemoryReceiverSessionEventLog):
outputs_unknown = proposal.check_no_inputs_seen_before(CheckInputsNotSeenCallback(self.receiver)).save(recv_persister)
return await self.process_outputs_unknown(outputs_unknown, recv_persister)

async def process_outputs_unknown(self, proposal: OutputsUnknown, recv_persister: InMemoryReceiverSessionEventLog):
wants_outputs = proposal.identify_receiver_outputs(IsScriptOwnedCallback(self.receiver)).save(recv_persister)
return await self.process_wants_outputs(wants_outputs, recv_persister)

async def process_wants_outputs(self, proposal: WantsOutputs, recv_persister: InMemoryReceiverSessionEventLog):
wants_inputs = proposal.commit_outputs().save(recv_persister)
return await self.process_wants_inputs(wants_inputs, recv_persister)

async def process_wants_inputs(self, proposal: WantsInputs, recv_persister: InMemoryReceiverSessionEventLog):
provisional_proposal = proposal.contribute_inputs(get_inputs(self.receiver)).commit_inputs().save(recv_persister)
return await self.process_provisional_proposal(provisional_proposal, recv_persister)

async def process_provisional_proposal(self, proposal: ProvisionalProposal, recv_persister: InMemoryReceiverSessionEventLog):
payjoin_proposal = proposal.finalize_proposal(ProcessPsbtCallback(self.receiver), 1, 10).save(recv_persister)
return ReceiveSession.PAYJOIN_PROPOSAL(payjoin_proposal)

async def test_integration_v2_to_v2(self):
try:
receiver_address = bitcoinffi.Address(json.loads(self.receiver.call("getnewaddress", [])), bitcoinffi.Network.REGTEST)
Expand All @@ -154,7 +153,7 @@ async def test_integration_v2_to_v2(self):
pj_uri = session.pj_uri()
psbt = build_sweep_psbt(self.sender, pj_uri)
req_ctx: WithReplyKey = SenderBuilder(psbt, pj_uri).build_recommended(1000).save(sender_persister)
request: RequestV2PostContext = req_ctx.extract_v2(ohttp_relay.as_string())
request: RequestV2PostContext = req_ctx.create_v2_post_request(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
Expand All @@ -172,7 +171,7 @@ async def test_integration_v2_to_v2(self):
self.assertEqual(payjoin_proposal.is_PAYJOIN_PROPOSAL(), True)

payjoin_proposal = payjoin_proposal.inner
request: RequestResponse = payjoin_proposal.extract_req(ohttp_relay.as_string())
request: RequestResponse = payjoin_proposal.create_post_request(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
Expand All @@ -184,7 +183,7 @@ async def test_integration_v2_to_v2(self):
# Inside the Sender:
# Sender checks, signs, finalizes, extracts, and broadcasts
# Replay post fallback to get the response
request: RequestOhttpContext = send_ctx.extract_req(ohttp_relay.as_string())
request: RequestOhttpContext = send_ctx.create_poll_request(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
Expand Down Expand Up @@ -254,7 +253,7 @@ def callback(self, tx):
return res
except Exception as e:
print(f"An error occurred: {e}")
return None
return None

class IsScriptOwnedCallback(IsScriptOwned):
def __init__(self, connection: RpcClient):
Expand All @@ -280,7 +279,7 @@ def __init__(self, connection: RpcClient):
self.connection = connection

def callback(self, psbt: str):
res = json.loads(self.connection.call("walletprocesspsbt", [psbt]))
res = json.loads(self.connection.call("walletprocesspsbt", [psbt]))
return res['psbt']

if __name__ == "__main__":
Expand Down
23 changes: 12 additions & 11 deletions payjoin-ffi/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,22 @@ impl
}

impl Initialized {
pub fn extract_req(
/// Construct an OHTTP encapsulated GET request, polling the mailbox for the Original PSBT
pub fn create_poll_request(
&self,
ohttp_relay: String,
) -> Result<(Request, ClientResponse), ReceiverError> {
self.0
.clone()
.extract_req(ohttp_relay)
.create_poll_request(ohttp_relay)
.map(|(req, ctx)| (req.into(), ctx.into()))
.map_err(Into::into)
}

///The response can either be an UncheckedProposal or an ACCEPTED message indicating no UncheckedProposal is available yet.
pub fn process_res(&self, body: &[u8], ctx: &ClientResponse) -> InitializedTransition {
/// The response can either be an UncheckedProposal or an ACCEPTED message indicating no UncheckedProposal is available yet.
pub fn process_response(&self, body: &[u8], ctx: &ClientResponse) -> InitializedTransition {
InitializedTransition(Arc::new(RwLock::new(Some(
self.0.clone().process_res(body, ctx.into()),
self.0.clone().process_response(body, ctx.into()),
))))
}

Expand Down Expand Up @@ -845,30 +846,30 @@ impl PayjoinProposal {
.to_string()
}

/// Extract an OHTTP Encapsulated HTTP POST request for the Proposal PSBT
pub fn extract_req(
/// Construct an OHTTP Encapsulated HTTP POST request for the Proposal PSBT
pub fn create_post_request(
&self,
ohttp_relay: String,
) -> Result<(Request, ClientResponse), ReceiverError> {
self.0
.clone()
.extract_req(ohttp_relay)
.create_post_request(ohttp_relay)
.map_err(Into::into)
.map(|(req, ctx)| (req.into(), ctx.into()))
}

///Processes the response for the final POST message from the receiver client in the v2 Payjoin protocol.
/// Processes the response for the final POST message from the receiver client in the v2 Payjoin protocol.
///
/// This function decapsulates the response using the provided OHTTP context. If the response status is successful, it indicates that the Payjoin proposal has been accepted. Otherwise, it returns an error with the status code.
///
/// After this function is called, the receiver can either wait for the Payjoin transaction to be broadcast or choose to broadcast the original PSBT.
pub fn process_res(
pub fn process_response(
&self,
body: &[u8],
ohttp_context: &ClientResponse,
) -> PayjoinProposalTransition {
PayjoinProposalTransition(Arc::new(RwLock::new(Some(
self.0.clone().process_res(body, ohttp_context.into()),
self.0.clone().process_response(body, ohttp_context.into()),
))))
}
}
Expand Down
Loading
Loading