diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index e652f5562..ec4df681f 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -9,7 +9,7 @@ use payjoin::receive::v2::{ WithContext, }; use payjoin::receive::{Error, ReplyableError}; -use payjoin::send::v2::{Sender, SenderBuilder}; +use payjoin::send::v2::{Sender, SenderBuilder, WithReplyKey}; use payjoin::{ImplementationError, Uri}; use tokio::sync::watch; @@ -137,7 +137,7 @@ impl AppTrait for App { impl App { #[allow(clippy::incompatible_msrv)] - async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> { + async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> { let mut interrupt = self.interrupt.clone(); tokio::select! { res = self.long_poll_post(&mut req_ctx) => { @@ -200,7 +200,7 @@ impl App { Ok(()) } - async fn long_poll_post(&self, req_ctx: &mut Sender) -> Result { + async fn long_poll_post(&self, req_ctx: &mut Sender) -> Result { let ohttp_relay = self.unwrap_relay_or_else_fetch().await?; match req_ctx.extract_v2(ohttp_relay.clone()) { diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 3f6b74871..f33161d4e 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use bitcoincore_rpc::jsonrpc::serde_json; use payjoin::persist::{Persister, Value}; use payjoin::receive::v2::{Receiver, ReceiverToken, WithContext}; -use payjoin::send::v2::{Sender, SenderToken}; +use payjoin::send::v2::{Sender, SenderToken, WithReplyKey}; use sled::Tree; use url::Url; @@ -14,10 +14,13 @@ impl SenderPersister { pub fn new(db: Arc) -> Self { Self(db) } } -impl Persister for SenderPersister { +impl Persister> for SenderPersister { type Token = SenderToken; type Error = crate::db::error::Error; - fn save(&mut self, value: Sender) -> std::result::Result { + fn save( + &mut self, + value: Sender, + ) -> std::result::Result { let send_tree = self.0 .0.open_tree("send_sessions")?; let key = value.key(); let value = serde_json::to_vec(&value).map_err(Error::Serialize)?; @@ -26,7 +29,7 @@ impl Persister for SenderPersister { Ok(key) } - fn load(&self, key: SenderToken) -> std::result::Result { + fn load(&self, key: SenderToken) -> std::result::Result, Self::Error> { let send_tree = self.0 .0.open_tree("send_sessions")?; let value = send_tree.get(key.as_ref())?.ok_or(Error::NotFound(key.to_string()))?; serde_json::from_slice(&value).map_err(Error::Deserialize) @@ -79,21 +82,23 @@ impl Database { Ok(()) } - pub(crate) fn get_send_sessions(&self) -> Result> { + pub(crate) fn get_send_sessions(&self) -> Result>> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let mut sessions = Vec::new(); for item in send_tree.iter() { let (_, value) = item?; - let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Sender = + serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) } - pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { + pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result>> { let send_tree = self.0.open_tree("send_sessions")?; if let Some(val) = send_tree.get(pj_url.as_str())? { - let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?; + let session: Sender = + serde_json::from_slice(&val).map_err(Error::Deserialize)?; Ok(Some(session)) } else { Ok(None) diff --git a/payjoin-ffi/python/test/test_payjoin_integration_test.py b/payjoin-ffi/python/test/test_payjoin_integration_test.py index dbc763ce6..90a56243c 100644 --- a/payjoin-ffi/python/test/test_payjoin_integration_test.py +++ b/payjoin-ffi/python/test/test_payjoin_integration_test.py @@ -50,15 +50,15 @@ def __init__(self): super().__init__() self.senders = {} - def save(self, sender: Sender) -> SenderToken: + def save(self, sender: WithReplyKey) -> SenderToken: self.senders[str(sender.key())] = sender.to_json() return sender.key() - def load(self, token: SenderToken) -> Sender: + def load(self, token: SenderToken) -> WithReplyKey: token = str(token) if token not in self.senders.keys(): raise ValueError(f"Token not found: {token}") - return Sender.from_json(self.senders[token]) + return WithReplyKey.from_json(self.senders[token]) class TestPayjoin(unittest.IsolatedAsyncioTestCase): @classmethod @@ -106,7 +106,7 @@ async def test_integration_v2_to_v2(self): new_sender = SenderBuilder(psbt, pj_uri).build_recommended(1000) persister = InMemorySenderPersister() token = new_sender.persist(persister) - req_ctx: Sender = Sender.load(token, persister) + req_ctx: WithReplyKey = WithReplyKey.load(token, persister) request: RequestV2PostContext = req_ctx.extract_v2(ohttp_relay.as_string()) response = await agent.post( url=request.request.url.as_string(), diff --git a/payjoin-ffi/python/test/test_payjoin_unit_test.py b/payjoin-ffi/python/test/test_payjoin_unit_test.py index 689cb8593..76aef390d 100644 --- a/payjoin-ffi/python/test/test_payjoin_unit_test.py +++ b/payjoin-ffi/python/test/test_payjoin_unit_test.py @@ -64,15 +64,15 @@ class InMemorySenderPersister(payjoin.payjoin_ffi.SenderPersister): def __init__(self): self.senders = {} - def save(self, sender: payjoin.Sender) -> payjoin.SenderToken: + def save(self, sender: payjoin.WithReplyKey) -> payjoin.SenderToken: self.senders[str(sender.key())] = sender.to_json() return sender.key() - def load(self, token: payjoin.SenderToken) -> payjoin.Sender: + def load(self, token: payjoin.SenderToken) -> payjoin.WithReplyKey: token = str(token) if token not in self.senders.keys(): raise ValueError(f"Token not found: {token}") - return payjoin.Sender.from_json(self.senders[token]) + return payjoin.WithReplyKey.from_json(self.senders[token]) class TestSenderPersistence(unittest.TestCase): def test_sender_persistence(self): @@ -93,7 +93,7 @@ def test_sender_persistence(self): psbt = "cHNidP8BAHMCAAAAAY8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////AtyVuAUAAAAAF6kUHehJ8GnSdBUOOv6ujXLrWmsJRDCHgIQeAAAAAAAXqRR3QJbbz0hnQ8IvQ0fptGn+votneofTAAAAAAEBIKgb1wUAAAAAF6kU3k4ekGHKWRNbA1rV5tR5kEVDVNCHAQcXFgAUx4pFclNVgo1WWAdN1SYNX8tphTABCGsCRzBEAiB8Q+A6dep+Rz92vhy26lT0AjZn4PRLi8Bf9qoB/CMk0wIgP/Rj2PWZ3gEjUkTlhDRNAQ0gXwTO7t9n+V14pZ6oljUBIQMVmsAaoNWHVMS02LfTSe0e388LNitPa1UQZyOihY+FFgABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUAAA=" new_sender = payjoin.SenderBuilder(psbt, uri).build_recommended(1000) token = new_sender.persist(persister) - payjoin.Sender.load(token, persister) + payjoin.WithReplyKey.load(token, persister) if __name__ == "__main__": unittest.main() diff --git a/payjoin-ffi/src/send/mod.rs b/payjoin-ffi/src/send/mod.rs index abfec19fc..ae2717a02 100644 --- a/payjoin-ffi/src/send/mod.rs +++ b/payjoin-ffi/src/send/mod.rs @@ -29,7 +29,7 @@ impl SenderBuilder { /// Prepare an HTTP request and request context to process the response /// /// Call [`SenderBuilder::build_recommended()`] or other `build` methods - /// to create a [`Sender`] + /// to create a [`WithReplyKey`] pub fn new(psbt: String, uri: PjUri) -> Result { let psbt = payjoin::bitcoin::psbt::Psbt::from_str(psbt.as_str())?; Ok(payjoin::send::v2::SenderBuilder::new(psbt, uri.into()).into()) @@ -114,7 +114,7 @@ impl From for NewSender { } impl NewSender { - pub fn persist>( + pub fn persist>>( &self, persister: &mut P, ) -> Result { @@ -123,18 +123,20 @@ impl NewSender { } #[derive(Clone)] -pub struct Sender(payjoin::send::v2::Sender); +pub struct WithReplyKey(payjoin::send::v2::Sender); -impl From for Sender { - fn from(value: payjoin::send::v2::Sender) -> Self { Self(value) } +impl From> for WithReplyKey { + fn from(value: payjoin::send::v2::Sender) -> Self { + Self(value) + } } -impl From for payjoin::send::v2::Sender { - fn from(value: Sender) -> Self { value.0 } +impl From for payjoin::send::v2::Sender { + fn from(value: WithReplyKey) -> Self { value.0 } } -impl Sender { - pub fn load>( +impl WithReplyKey { + pub fn load>>( token: P::Token, persister: &P, ) -> Result { @@ -164,7 +166,9 @@ impl Sender { } pub fn from_json(json: &str) -> Result { - serde_json::from_str::(json).map_err(Into::into).map(Into::into) + serde_json::from_str::>(json) + .map_err(Into::into) + .map(Into::into) } pub fn key(&self) -> SenderToken { self.0.key() } @@ -190,35 +194,43 @@ impl V1Context { } } -pub struct V2PostContext(Mutex>); +pub struct V2PostContext( + Mutex>>, +); impl V2PostContext { /// Decodes and validates the response. /// Call this method with response from receiver to continue BIP-??? flow. A successful response can either be None if the relay has not response yet or Some(Psbt). /// If the response is some valid PSBT you should sign and broadcast. pub fn process_response(&self, response: &[u8]) -> Result { - <&V2PostContext as Into>::into(self) - .process_response(response) - .map(Into::into) - .map_err(Into::into) + <&V2PostContext as Into>>::into( + self, + ) + .process_response(response) + .map(Into::into) + .map_err(Into::into) } } -impl From<&V2PostContext> for payjoin::send::v2::V2PostContext { +impl From<&V2PostContext> for payjoin::send::v2::Sender { fn from(value: &V2PostContext) -> Self { let mut data_guard = value.0.lock().unwrap(); Option::take(&mut *data_guard).expect("ContextV2 moved out of memory") } } -impl From for V2PostContext { - fn from(value: payjoin::send::v2::V2PostContext) -> Self { Self(Mutex::new(Some(value))) } +impl From> for V2PostContext { + fn from(value: payjoin::send::v2::Sender) -> Self { + Self(Mutex::new(Some(value))) + } } -pub struct V2GetContext(payjoin::send::v2::V2GetContext); +pub struct V2GetContext(payjoin::send::v2::Sender); -impl From for V2GetContext { - fn from(value: payjoin::send::v2::V2GetContext) -> Self { Self(value) } +impl From> for V2GetContext { + fn from(value: payjoin::send::v2::Sender) -> Self { + Self(value) + } } impl V2GetContext { diff --git a/payjoin-ffi/src/send/uni.rs b/payjoin-ffi/src/send/uni.rs index fe592dec0..1c1c8b4a8 100644 --- a/payjoin-ffi/src/send/uni.rs +++ b/payjoin-ffi/src/send/uni.rs @@ -22,7 +22,7 @@ impl SenderBuilder { /// Prepare an HTTP request and request context to process the response /// /// Call [`SenderBuilder::build_recommended()`] or other `build` methods - /// to create a [`Sender`] + /// to create a [`WithReplyKey`] #[uniffi::constructor] pub fn new(psbt: String, uri: Arc) -> Result { super::SenderBuilder::new(psbt, (*uri).clone()).map(Into::into) @@ -107,24 +107,24 @@ impl NewSender { } #[derive(Clone, uniffi::Object)] -pub struct Sender(super::Sender); +pub struct WithReplyKey(super::WithReplyKey); -impl From for Sender { - fn from(value: super::Sender) -> Self { Self(value) } +impl From for WithReplyKey { + fn from(value: super::WithReplyKey) -> Self { Self(value) } } -impl From for super::Sender { - fn from(value: Sender) -> Self { value.0 } +impl From for super::WithReplyKey { + fn from(value: WithReplyKey) -> Self { value.0 } } #[uniffi::export] -impl Sender { +impl WithReplyKey { #[uniffi::constructor] pub fn load( token: Arc, persister: Arc, ) -> Result { - Ok(super::Sender::from( + Ok(super::WithReplyKey::from( (*persister.load(token).map_err(|e| ImplementationError::from(e.to_string()))?).clone(), ) .into()) @@ -154,7 +154,7 @@ impl Sender { #[uniffi::constructor] pub fn from_json(json: &str) -> Result { - super::Sender::from_json(json).map(Into::into) + super::WithReplyKey::from_json(json).map(Into::into) } pub fn key(&self) -> SenderToken { self.0.key().into() } @@ -248,8 +248,8 @@ impl V2GetContext { #[uniffi::export(with_foreign)] pub trait SenderPersister: Send + Sync { - fn save(&self, sender: Arc) -> Result, ForeignError>; - fn load(&self, token: Arc) -> Result, ForeignError>; + fn save(&self, sender: Arc) -> Result, ForeignError>; + fn load(&self, token: Arc) -> Result, ForeignError>; } // The adapter to use the save and load callbacks @@ -262,16 +262,24 @@ impl CallbackPersisterAdapter { } // Implement the Persister trait for the adapter -impl payjoin::persist::Persister for CallbackPersisterAdapter { +impl payjoin::persist::Persister> + for CallbackPersisterAdapter +{ type Token = SenderToken; // Define the token type type Error = ForeignError; // Define the error type - fn save(&mut self, sender: payjoin::send::v2::Sender) -> Result { - let sender = Sender(super::Sender::from(sender)); + fn save( + &mut self, + sender: payjoin::send::v2::Sender, + ) -> Result { + let sender = WithReplyKey(super::WithReplyKey::from(sender)); self.callback_persister.save(sender.into()).map(|token| (*token).clone()) } - fn load(&self, token: Self::Token) -> Result { + fn load( + &self, + token: Self::Token, + ) -> Result, Self::Error> { // Use the callback to load the sender self.callback_persister.load(token.into()).map(|sender| (*sender).clone().0 .0) } @@ -285,8 +293,10 @@ impl std::fmt::Display for SenderToken { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } -impl From for SenderToken { - fn from(value: payjoin::send::v2::Sender) -> Self { SenderToken(value.into()) } +impl From> for SenderToken { + fn from(value: payjoin::send::v2::Sender) -> Self { + SenderToken(value.into()) + } } impl From for SenderToken { diff --git a/payjoin-ffi/tests/bdk_integration_test.rs b/payjoin-ffi/tests/bdk_integration_test.rs index 0dff2532e..5f1f7eac4 100644 --- a/payjoin-ffi/tests/bdk_integration_test.rs +++ b/payjoin-ffi/tests/bdk_integration_test.rs @@ -220,7 +220,7 @@ mod v2 { use bdk::wallet::AddressIndex; use bitcoin_ffi::{Address, Network}; use payjoin_ffi::receive::{NewReceiver, PayjoinProposal, UncheckedProposal, WithContext}; - use payjoin_ffi::send::{Sender, SenderBuilder}; + use payjoin_ffi::send::{WithReplyKey, SenderBuilder}; use payjoin_ffi::uri::Uri; use payjoin_ffi::{NoopPersister, Request}; use payjoin_test_utils::TestServices; @@ -288,7 +288,7 @@ mod v2 { let new_sender = SenderBuilder::new(psbt.to_string(), pj_uri)? .build_recommended(payjoin::bitcoin::FeeRate::BROADCAST_MIN.to_sat_per_kwu())?; let sender_token = new_sender.persist(&mut NoopPersister)?; - let req_ctx = Sender::load(sender_token, &NoopPersister)?; + let req_ctx = WithReplyKey::load(sender_token, &NoopPersister)?; let (request, context) = req_ctx.extract_v2(ohttp_relay.to_owned().into())?; let response = agent .post(request.url.as_string()) diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 1baa2490a..398e64077 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -127,18 +127,20 @@ impl NewReceiver { } } +pub trait ReceiverState {} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Receiver { +pub struct Receiver { pub(crate) state: State, } -impl core::ops::Deref for Receiver { +impl core::ops::Deref for Receiver { type Target = State; fn deref(&self) -> &Self::Target { &self.state } } -impl core::ops::DerefMut for Receiver { +impl core::ops::DerefMut for Receiver { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } } @@ -147,6 +149,8 @@ pub struct WithContext { context: SessionContext, } +impl ReceiverState for WithContext {} + impl Receiver { /// Loads a [`Receiver`] from the provided persister using the storage token. pub fn load>>( @@ -281,6 +285,8 @@ pub struct UncheckedProposal { pub(crate) context: SessionContext, } +impl ReceiverState for UncheckedProposal {} + impl Receiver { /// The Sender's Original PSBT pub fn extract_tx_to_schedule_broadcast(&self) -> bitcoin::Transaction { @@ -366,6 +372,8 @@ pub struct MaybeInputsOwned { context: SessionContext, } +impl ReceiverState for MaybeInputsOwned {} + impl Receiver { /// Check that the Original PSBT has no receiver-owned inputs. /// Return original-psbt-rejected error or otherwise refuse to sign undesirable inputs. @@ -389,6 +397,8 @@ pub struct MaybeInputsSeen { context: SessionContext, } +impl ReceiverState for MaybeInputsSeen {} + impl Receiver { /// Make sure that the original transaction inputs have never been seen before. /// This prevents probing attacks. This prevents reentrant Payjoin, where a sender @@ -412,6 +422,8 @@ pub struct OutputsUnknown { context: SessionContext, } +impl ReceiverState for OutputsUnknown {} + impl Receiver { /// Find which outputs belong to the receiver pub fn identify_receiver_outputs( @@ -432,6 +444,8 @@ pub struct WantsOutputs { context: SessionContext, } +impl ReceiverState for WantsOutputs {} + impl Receiver { /// Whether the receiver is allowed to substitute original outputs or not. pub fn output_substitution(&self) -> OutputSubstitution { self.v1.output_substitution() } @@ -476,6 +490,8 @@ pub struct WantsInputs { context: SessionContext, } +impl ReceiverState for WantsInputs {} + impl Receiver { /// Select receiver input such that the payjoin avoids surveillance. /// Return the input chosen that has been applied to the Proposal. @@ -523,6 +539,8 @@ pub struct ProvisionalProposal { context: SessionContext, } +impl ReceiverState for ProvisionalProposal {} + impl Receiver { /// Return a Payjoin Proposal PSBT that the sender will find acceptable. /// @@ -553,6 +571,8 @@ pub struct PayjoinProposal { context: SessionContext, } +impl ReceiverState for PayjoinProposal {} + impl PayjoinProposal { #[cfg(feature = "_multiparty")] // TODO hack to get multi party working. A better solution would be to allow extract_req to be separate from the rest of the v2 context diff --git a/payjoin/src/send/multiparty/mod.rs b/payjoin/src/send/multiparty/mod.rs index 75e23b61e..bb6140ba4 100644 --- a/payjoin/src/send/multiparty/mod.rs +++ b/payjoin/src/send/multiparty/mod.rs @@ -33,7 +33,7 @@ impl<'a> SenderBuilder<'a> { pub struct NewSender(v2::NewSender); #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Sender(v2::Sender); +pub struct Sender(v2::Sender); impl Sender { pub fn extract_v2( @@ -76,7 +76,7 @@ impl Sender { hpke_ctx: HpkeContext::new(rs, &self.0.reply_key), ohttp_ctx, }; - Ok((request, PostContext(v2_post_ctx))) + Ok((request, PostContext(v2::Sender { state: v2_post_ctx }))) } } @@ -100,7 +100,7 @@ fn serialize_v2_body( /// Post context is used to process the response from the directory and generate /// the GET context which can be used to extract a request for the receiver -pub struct PostContext(v2::V2PostContext); +pub struct PostContext(v2::Sender); impl PostContext { pub fn process_response(self, response: &[u8]) -> Result { @@ -111,7 +111,7 @@ impl PostContext { /// Get context is used to extract a request for the receiver. In the multiparty context this is a /// merged PSBT with other senders. -pub struct GetContext(v2::V2GetContext); +pub struct GetContext(v2::Sender); impl GetContext { /// Extract the GET request that will give us the psbt to be finalized diff --git a/payjoin/src/send/multiparty/persist.rs b/payjoin/src/send/multiparty/persist.rs index 16e034203..4674959f1 100644 --- a/payjoin/src/send/multiparty/persist.rs +++ b/payjoin/src/send/multiparty/persist.rs @@ -7,8 +7,9 @@ impl NewSender { &self, persister: &mut P, ) -> Result { - let sender = - Sender(v2::Sender { v1: self.0.v1.clone(), reply_key: self.0.reply_key.clone() }); + let sender = Sender(v2::Sender { + state: v2::WithReplyKey { v1: self.0.v1.clone(), reply_key: self.0.reply_key.clone() }, + }); persister.save(sender).map_err(ImplementationError::from) } } diff --git a/payjoin/src/send/v2/mod.rs b/payjoin/src/send/v2/mod.rs index 9d7d6030c..11bd004b9 100644 --- a/payjoin/src/send/v2/mod.rs +++ b/payjoin/src/send/v2/mod.rs @@ -124,6 +124,23 @@ impl<'a> SenderBuilder<'a> { } } +pub trait SenderState {} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Sender { + pub(crate) state: State, +} + +impl core::ops::Deref for Sender { + type Target = State; + + fn deref(&self) -> &Self::Target { &self.state } +} + +impl core::ops::DerefMut for Sender { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } +} + /// A new payjoin sender, which must be persisted before initiating the payjoin flow. #[derive(Debug)] pub struct NewSender { @@ -133,11 +150,13 @@ pub struct NewSender { impl NewSender { /// Saves the new [`Sender`] using the provided persister and returns the storage token. - pub fn persist>( + pub fn persist>>( &self, persister: &mut P, ) -> Result { - let sender = Sender { v1: self.v1.clone(), reply_key: self.reply_key.clone() }; + let sender = Sender { + state: WithReplyKey { v1: self.v1.clone(), reply_key: self.reply_key.clone() }, + }; Ok(persister.save(sender)?) } } @@ -145,16 +164,18 @@ impl NewSender { /// A payjoin V2 sender, allowing the construction of a payjoin V2 request /// and the resulting [`V2PostContext`]. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Sender { +pub struct WithReplyKey { /// The v1 Sender. pub(crate) v1: v1::Sender, /// The secret key to decrypt the receiver's reply. pub(crate) reply_key: HpkeSecretKey, } -impl Sender { +impl SenderState for WithReplyKey {} + +impl Sender { /// Loads a [`Sender`] from the provided persister using the storage token. - pub fn load>( + pub fn load>>( token: P::Token, persister: &P, ) -> Result { @@ -170,7 +191,7 @@ impl Sender { pub fn extract_v2( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, V2PostContext), CreateRequestError> { + ) -> Result<(Request, Sender), CreateRequestError> { if let Ok(expiry) = self.v1.endpoint.exp() { if std::time::SystemTime::now() > expiry { return Err(InternalCreateRequestError::Expired(expiry).into()); @@ -199,17 +220,19 @@ impl Sender { let rs = self.extract_rs_pubkey()?; Ok(( request, - V2PostContext { - endpoint: self.v1.endpoint.clone(), - psbt_ctx: PsbtContext { - original_psbt: self.v1.psbt.clone(), - output_substitution: self.v1.output_substitution, - fee_contribution: self.v1.fee_contribution, - payee: self.v1.payee.clone(), - min_fee_rate: self.v1.min_fee_rate, + Sender { + state: V2PostContext { + endpoint: self.v1.endpoint.clone(), + psbt_ctx: PsbtContext { + original_psbt: self.v1.psbt.clone(), + output_substitution: self.v1.output_substitution, + fee_contribution: self.v1.fee_contribution, + payee: self.v1.payee.clone(), + min_fee_rate: self.v1.min_fee_rate, + }, + hpke_ctx: HpkeContext::new(rs, &self.reply_key), + ohttp_ctx, }, - hpke_ctx: HpkeContext::new(rs, &self.reply_key), - ohttp_ctx, }, )) } @@ -276,7 +299,7 @@ pub(crate) fn serialize_v2_body( /// Data required to validate the POST response. /// /// This type is used to process a BIP77 POST response. -/// Call [`Self::process_response`] on it to continue the BIP77 flow. +/// Call [`Sender::process_response`] on it to continue the BIP77 flow. pub struct V2PostContext { /// The endpoint in the Payjoin URI pub(crate) endpoint: Url, @@ -285,7 +308,9 @@ pub struct V2PostContext { pub(crate) ohttp_ctx: ohttp::ClientResponse, } -impl V2PostContext { +impl SenderState for V2PostContext {} + +impl Sender { /// Processes the response for the initial POST message from the sender /// client in the v2 Payjoin protocol. /// @@ -296,19 +321,24 @@ impl V2PostContext { /// /// After this function is called, the sender can poll for a Proposal PSBT /// from the receiver using the returned [`V2GetContext`]. - pub fn process_response(self, response: &[u8]) -> Result { + pub fn process_response( + self, + response: &[u8], + ) -> Result, EncapsulationError> { let response_array: &[u8; crate::directory::ENCAPSULATED_MESSAGE_BYTES] = response .try_into() .map_err(|_| InternalEncapsulationError::InvalidSize(response.len()))?; - let response = ohttp_decapsulate(self.ohttp_ctx, response_array) + let response = ohttp_decapsulate(self.state.ohttp_ctx, response_array) .map_err(InternalEncapsulationError::Ohttp)?; match response.status() { http::StatusCode::OK => { // return OK with new Typestate - Ok(V2GetContext { - endpoint: self.endpoint, - psbt_ctx: self.psbt_ctx, - hpke_ctx: self.hpke_ctx, + Ok(Sender { + state: V2GetContext { + endpoint: self.state.endpoint, + psbt_ctx: self.state.psbt_ctx, + hpke_ctx: self.state.hpke_ctx, + }, }) } _ => Err(InternalEncapsulationError::UnexpectedStatusCode(response.status()))?, @@ -319,7 +349,7 @@ impl V2PostContext { /// Data required to validate the GET response. /// /// This type is used to make a BIP77 GET request and process the response. -/// Call [`Self::process_response`] on it to continue the BIP77 flow. +/// Call [`Sender::process_response`] on it to continue the BIP77 flow. #[derive(Debug, Clone)] pub struct V2GetContext { /// The endpoint in the Payjoin URI @@ -328,7 +358,9 @@ pub struct V2GetContext { pub(crate) hpke_ctx: HpkeContext, } -impl V2GetContext { +impl SenderState for V2GetContext {} + +impl Sender { /// Extract an OHTTP Encapsulated HTTP GET request for the Proposal PSBT pub fn extract_req( &self, @@ -422,18 +454,20 @@ mod test { const SERIALIZED_BODY_V2: &str = "63484e696450384241484d43414141414159386e757447674a647959475857694245623435486f65396c5747626b78682f36624e694f4a6443447544414141414141442b2f2f2f2f41747956754155414141414146366b554865684a38476e536442554f4f7636756a584c72576d734a5244434867495165414141414141415871525233514a62627a30686e513849765130667074476e2b766f746e656f66544141414141414542494b6762317755414141414146366b55336b34656b47484b57524e6241317256357452356b455644564e4348415163584667415578347046636c4e56676f31575741644e3153594e583874706854414243477343527a424541694238512b41366465702b527a393276687932366c5430416a5a6e3450524c6938426639716f422f434d6b30774967502f526a3250575a3367456a556b546c6844524e415130675877544f3774396e2b563134705a366f6c6a554249514d566d7341616f4e5748564d5330324c6654536530653338384c4e697450613155515a794f6968592b464667414241425941464562324769753663344b4f35595730706677336c4770396a4d55554141413d0a763d32"; - fn create_sender_context() -> Result { + fn create_sender_context() -> Result, BoxError> { let endpoint = Url::parse("http://localhost:1234")?; let mut sender = super::Sender { - v1: v1::Sender { - psbt: PARSED_ORIGINAL_PSBT.clone(), - endpoint, - output_substitution: OutputSubstitution::Enabled, - fee_contribution: None, - min_fee_rate: FeeRate::ZERO, - payee: ScriptBuf::from(vec![0x00]), + state: super::WithReplyKey { + v1: v1::Sender { + psbt: PARSED_ORIGINAL_PSBT.clone(), + endpoint, + output_substitution: OutputSubstitution::Enabled, + fee_contribution: None, + min_fee_rate: FeeRate::ZERO, + payee: ScriptBuf::from(vec![0x00]), + }, + reply_key: HpkeKeyPair::gen_keypair().0, }, - reply_key: HpkeKeyPair::gen_keypair().0, }; sender.v1.endpoint.set_exp(SystemTime::now() + Duration::from_secs(60)); sender.v1.endpoint.set_receiver_pubkey(HpkeKeyPair::gen_keypair().1); diff --git a/payjoin/src/send/v2/persist.rs b/payjoin/src/send/v2/persist.rs index 90781dd05..16df89355 100644 --- a/payjoin/src/send/v2/persist.rs +++ b/payjoin/src/send/v2/persist.rs @@ -2,7 +2,7 @@ use std::fmt::{self, Display}; use url::Url; -use super::Sender; +use super::{Sender, WithReplyKey}; use crate::persist::Value; /// Opaque key type for the sender @@ -13,15 +13,15 @@ impl Display for SenderToken { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } -impl From for SenderToken { - fn from(sender: Sender) -> Self { SenderToken(sender.endpoint().clone()) } +impl From> for SenderToken { + fn from(sender: Sender) -> Self { SenderToken(sender.endpoint().clone()) } } impl AsRef<[u8]> for SenderToken { fn as_ref(&self) -> &[u8] { self.0.as_str().as_bytes() } } -impl Value for Sender { +impl Value for Sender { type Key = SenderToken; fn key(&self) -> Self::Key { SenderToken(self.endpoint().clone()) }