Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions payjoin-cli/src/app/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl AppTrait for App {
Some((sender_state, persister)) => (sender_state, persister),
None => {
let persister =
SenderPersister::new(self.db.clone(), receiver_pubkey.clone())?;
SenderPersister::new(self.db.clone(), bip21, receiver_pubkey)?;
let psbt = self.create_original_psbt(&address, amount, fee_rate)?;
let sender =
SenderBuilder::from_parts(psbt, pj_param, &address, Some(amount))
Expand Down Expand Up @@ -307,17 +307,17 @@ impl AppTrait for App {

// Process sender sessions
for session_id in send_session_ids {
let sender_persiter = SenderPersister::from_id(self.db.clone(), session_id.clone());
match replay_sender_event_log(&sender_persiter) {
let sender_persister = SenderPersister::from_id(self.db.clone(), session_id.clone());
match replay_sender_event_log(&sender_persister) {
Ok((sender_state, _)) => {
let self_clone = self.clone();
tasks.push(tokio::spawn(async move {
self_clone.process_sender_session(sender_state, &sender_persiter).await
self_clone.process_sender_session(sender_state, &sender_persister).await
}));
}
Err(e) => {
tracing::error!("An error {:?} occurred while replaying Sender session", e);
Self::close_failed_session(&sender_persiter, &session_id, "sender");
Self::close_failed_session(&sender_persister, &session_id, "sender");
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions payjoin-cli/src/db/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ pub(crate) enum Error {
Serialize(serde_json::Error),
#[cfg(feature = "v2")]
Deserialize(serde_json::Error),
#[cfg(feature = "v2")]
DuplicateUri,
#[cfg(feature = "v2")]
DuplicateRk,
}

impl fmt::Display for Error {
Expand All @@ -25,6 +29,13 @@ impl fmt::Display for Error {
Error::Serialize(e) => write!(f, "Serialization failed: {e}"),
#[cfg(feature = "v2")]
Error::Deserialize(e) => write!(f, "Deserialization failed: {e}"),
#[cfg(feature = "v2")]
Error::DuplicateUri => write!(f, "A send session for this URI is already active"),
#[cfg(feature = "v2")]
Error::DuplicateRk => write!(
f,
"A send session with this receiver pubkey is already active under a different URI"
),
}
}
}
Expand All @@ -38,6 +49,10 @@ impl std::error::Error for Error {
Error::Serialize(e) => Some(e),
#[cfg(feature = "v2")]
Error::Deserialize(e) => Some(e),
#[cfg(feature = "v2")]
Error::DuplicateUri => None,
#[cfg(feature = "v2")]
Error::DuplicateRk => None,
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion payjoin-cli/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ pub(crate) fn now() -> i64 {

pub(crate) const DB_PATH: &str = "payjoin.sqlite";

#[derive(Debug)]
pub(crate) struct Database(Pool<SqliteConnectionManager>);

impl Database {
pub(crate) fn create(path: impl AsRef<Path>) -> Result<Self> {
let manager = SqliteConnectionManager::file(path.as_ref());
let manager = SqliteConnectionManager::file(path.as_ref())
.with_init(|conn| conn.execute_batch("PRAGMA locking_mode = EXCLUSIVE;"));
let pool = Pool::new(manager)?;

// Initialize database schema
Expand All @@ -36,6 +38,7 @@ impl Database {
conn.execute(
"CREATE TABLE IF NOT EXISTS send_sessions (
session_id INTEGER PRIMARY KEY AUTOINCREMENT,
pj_uri TEXT NOT NULL,
receiver_pubkey BLOB NOT NULL,
completed_at INTEGER
)",
Expand Down
104 changes: 99 additions & 5 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,40 @@ impl std::fmt::Display for SessionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) }
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub(crate) struct SenderPersister {
db: Arc<Database>,
session_id: SessionId,
}

impl SenderPersister {
pub fn new(db: Arc<Database>, receiver_pubkey: HpkePublicKey) -> crate::db::Result<Self> {
pub fn new(
db: Arc<Database>,
pj_uri: &str,
receiver_pubkey: &HpkePublicKey,
) -> crate::db::Result<Self> {
let conn = db.get_connection()?;
let receiver_pubkey_bytes = receiver_pubkey.to_compressed_bytes();

let (duplicate_uri, duplicate_rk): (bool, bool) = conn.query_row(
"SELECT \
EXISTS(SELECT 1 FROM send_sessions WHERE completed_at IS NULL AND pj_uri = ?1), \
EXISTS(SELECT 1 FROM send_sessions WHERE completed_at IS NULL AND receiver_pubkey = ?2)",
params![pj_uri, &receiver_pubkey_bytes.as_slice()],
|row| Ok((row.get(0)?, row.get(1)?)),
)?;

if duplicate_uri {
return Err(Error::DuplicateUri);
}
if duplicate_rk {
return Err(Error::DuplicateRk);
}

// Create a new session in send_sessions and get its ID
let session_id: i64 = conn.query_row(
"INSERT INTO send_sessions (session_id, receiver_pubkey) VALUES (NULL, ?1) RETURNING session_id",
params![receiver_pubkey.to_compressed_bytes()],
"INSERT INTO send_sessions (pj_uri, receiver_pubkey) VALUES (?1, ?2) RETURNING session_id",
params![pj_uri, &receiver_pubkey_bytes.as_slice()],
|row| row.get(0),
)?;

Expand All @@ -42,7 +62,6 @@ impl SenderPersister {

pub fn from_id(db: Arc<Database>, id: SessionId) -> Self { Self { db, session_id: id } }
}

impl SessionPersister for SenderPersister {
type SessionEvent = SenderSessionEvent;
type InternalStorageError = crate::db::error::Error;
Expand Down Expand Up @@ -268,3 +287,78 @@ impl Database {
Ok(session_ids)
}
}

#[cfg(all(test, feature = "v2"))]
mod tests {
use std::sync::Arc;

use payjoin::HpkeKeyPair;

use super::*;

fn create_test_db() -> Arc<Database> {
// Use an in-memory database for tests
let manager = r2d2_sqlite::SqliteConnectionManager::memory()
.with_init(|conn| conn.execute_batch("PRAGMA locking_mode = EXCLUSIVE;"));
let pool = r2d2::Pool::new(manager).expect("pool creation should succeed");
let conn = pool.get().expect("connection should succeed");
Database::init_schema(&conn).expect("schema init should succeed");
Arc::new(Database(pool))
}

fn make_receiver_pubkey() -> payjoin::HpkePublicKey { HpkeKeyPair::gen_keypair().1 }

/// Second call with the same URI (same active session) should return DuplicateUri.
#[test]
fn test_duplicate_uri_returns_error() {
let db = create_test_db();
let rk1 = make_receiver_pubkey();
let rk2 = make_receiver_pubkey();
let uri = "bitcoin:addr1?pj=https://example.com/BBBBBBBB";

SenderPersister::new(db.clone(), uri, &rk1).expect("first session should succeed");

let err = SenderPersister::new(db, uri, &rk2).expect_err("duplicate URI should fail");
assert!(
matches!(err, crate::db::error::Error::DuplicateUri),
"expected DuplicateUri, got: {err:?}"
);
}

/// Same receiver pubkey under a different URI should return DuplicateRk.
#[test]
fn test_duplicate_rk_returns_error() {
let db = create_test_db();
let rk = make_receiver_pubkey();
let uri1 = "bitcoin:addr1?pj=https://example.com/CCCCCCCC";
let uri2 = "bitcoin:addr1?pj=https://example.com/DDDDDDDD";

SenderPersister::new(db.clone(), uri1, &rk).expect("first session should succeed");

let err = SenderPersister::new(db, uri2, &rk).expect_err("duplicate RK should fail");
assert!(
matches!(err, crate::db::error::Error::DuplicateRk),
"expected DuplicateRk, got: {err:?}"
);
}

/// After a session is marked completed, a new session with the same URI should be allowed.
#[test]
fn test_completed_session_allows_reuse() {
let db = create_test_db();
let rk1 = make_receiver_pubkey();
let rk2 = make_receiver_pubkey();
let uri = "bitcoin:addr1?pj=https://example.com/EEEEEEEE";

let persister =
SenderPersister::new(db.clone(), uri, &rk1).expect("first session should succeed");

// Mark the session as completed
use payjoin::persist::SessionPersister;
persister.close().expect("close should succeed");

// Now a new session with the same URI should succeed (completed sessions don't block)
let result = SenderPersister::new(db, uri, &rk2);
assert!(result.is_ok(), "reuse after completion should succeed");
}
}
Loading