]> Untitled Git - bdk-cli/commitdiff
Add SQLite backing store for payjoin sessions
authorMshehu5 <musheu@gmail.com>
Mon, 16 Feb 2026 13:31:08 +0000 (14:31 +0100)
committerMshehu5 <musheu@gmail.com>
Thu, 2 Jul 2026 11:30:16 +0000 (12:30 +0100)
Persist payjoin sender and receiver state in SQLite so interrupted
payjoin sessions can be resumed after the CLI exits. Add dedicated
tables for send and receive sessions, append-only event logs for
state replay, receiver pubkey lookup for sender sessions, and
seen-input tracking for replay protection.

This follows the intended async payjoin design by persisting session
state across interruptions. SQLite keeps the initial persistence
backend simple and builds on existing rusqlite support, at the cost
of a small payjoin-specific schema and serialization layer.

src/error.rs
src/payjoin/db.rs [new file with mode: 0644]

index a3e046248a7287527f8d587ffce6911af1a196e1..dbdc7b8112ceaa8ef2471dc05d093f81c8564240 100644 (file)
@@ -152,6 +152,10 @@ pub enum BDKCliError {
     #[error("Payjoin create request error: {0}")]
     PayjoinCreateRequest(#[from] payjoin::send::v2::CreateRequestError),
 
+    #[cfg(feature = "payjoin")]
+    #[error("Payjoin database error: {0}")]
+    PayjoinDb(#[from] crate::payjoin::db::Error),
+
     #[cfg(feature = "bip322")]
     #[error("BIP-322 error: {0}")]
     Bip322Error(#[from] bdk_bip322::error::Error),
@@ -183,3 +187,16 @@ impl From<bdk_wallet::rusqlite::Error> for BDKCliError {
         BDKCliError::RusqliteError(Box::new(err))
     }
 }
+
+#[cfg(feature = "payjoin")]
+impl<ApiErr, StorageErr, ErrorState>
+    From<payjoin::persist::PersistedError<ApiErr, StorageErr, ErrorState>> for BDKCliError
+where
+    ApiErr: std::error::Error,
+    StorageErr: std::error::Error,
+    ErrorState: std::fmt::Debug,
+{
+    fn from(e: payjoin::persist::PersistedError<ApiErr, StorageErr, ErrorState>) -> Self {
+        BDKCliError::Generic(e.to_string())
+    }
+}
diff --git a/src/payjoin/db.rs b/src/payjoin/db.rs
new file mode 100644 (file)
index 0000000..39089e1
--- /dev/null
@@ -0,0 +1,449 @@
+use std::fmt;
+use std::path::{Path, PathBuf};
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use bdk_wallet::rusqlite::{Connection, ToSql, params, types::ToSqlOutput};
+use payjoin::HpkePublicKey;
+use payjoin::bitcoin::OutPoint;
+use payjoin::bitcoin::consensus::encode::serialize;
+use payjoin::persist::SessionPersister;
+use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent;
+use payjoin::send::v2::SessionEvent as SenderSessionEvent;
+
+use crate::error::BDKCliError;
+use crate::utils::prepare_home_dir;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Error type for payjoin database operations
+#[derive(Debug)]
+pub enum Error {
+    /// SQLite database error
+    Rusqlite(bdk_wallet::rusqlite::Error),
+    /// JSON serialization error
+    Serialize(serde_json::Error),
+    /// JSON deserialization error
+    Deserialize(serde_json::Error),
+}
+
+impl std::fmt::Display for Error {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Error::Rusqlite(e) => write!(f, "Database operation failed: {e}"),
+            Error::Serialize(e) => write!(f, "Serialization failed: {e}"),
+            Error::Deserialize(e) => write!(f, "Deserialization failed: {e}"),
+        }
+    }
+}
+
+impl std::error::Error for Error {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match self {
+            Error::Rusqlite(e) => Some(e),
+            Error::Serialize(e) => Some(e),
+            Error::Deserialize(e) => Some(e),
+        }
+    }
+}
+
+impl From<bdk_wallet::rusqlite::Error> for Error {
+    fn from(error: bdk_wallet::rusqlite::Error) -> Self {
+        Error::Rusqlite(error)
+    }
+}
+
+impl From<Error> for payjoin::ImplementationError {
+    fn from(error: Error) -> Self {
+        payjoin::ImplementationError::new(error)
+    }
+}
+
+/// Default filename for the payjoin database
+pub const DB_FILENAME: &str = "payjoin.sqlite";
+
+pub fn open_payjoin_db(
+    datadir: Option<PathBuf>,
+    wallet_name: &str,
+) -> std::result::Result<Arc<Database>, BDKCliError> {
+    let wallet_dir = prepare_home_dir(datadir)?.join(wallet_name);
+    std::fs::create_dir_all(&wallet_dir).map_err(|e| BDKCliError::Generic(e.to_string()))?;
+    Ok(Arc::new(Database::create(wallet_dir.join(DB_FILENAME))?))
+}
+
+/// Returns the current Unix timestamp in seconds
+#[inline]
+fn now() -> i64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap()
+        .as_secs() as i64
+}
+
+pub struct Database {
+    conn: Mutex<Connection>,
+}
+
+impl Database {
+    pub fn create(path: impl AsRef<Path>) -> Result<Self> {
+        let conn = Connection::open(path.as_ref())?;
+        Self::init_schema(&conn)?;
+        Ok(Self {
+            conn: Mutex::new(conn),
+        })
+    }
+
+    fn conn(&self) -> MutexGuard<'_, Connection> {
+        self.conn
+            .lock()
+            .expect("Database mutex should not be poisoned")
+    }
+
+    fn init_schema(conn: &Connection) -> Result<()> {
+        conn.execute("PRAGMA foreign_keys = ON", [])?;
+
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS send_sessions (
+                session_id INTEGER PRIMARY KEY AUTOINCREMENT,
+                receiver_pubkey BLOB NOT NULL,
+                completed_at INTEGER
+            )",
+            [],
+        )?;
+
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS receive_sessions (
+                session_id INTEGER PRIMARY KEY AUTOINCREMENT,
+                completed_at INTEGER
+            )",
+            [],
+        )?;
+
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS send_session_events (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                session_id INTEGER NOT NULL,
+                event_data TEXT NOT NULL,
+                created_at INTEGER NOT NULL,
+                FOREIGN KEY(session_id) REFERENCES send_sessions(session_id)
+            )",
+            [],
+        )?;
+
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS receive_session_events (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                session_id INTEGER NOT NULL,
+                event_data TEXT NOT NULL,
+                created_at INTEGER NOT NULL,
+                FOREIGN KEY(session_id) REFERENCES receive_sessions(session_id)
+            )",
+            [],
+        )?;
+
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS inputs_seen (
+                outpoint BLOB PRIMARY KEY,
+                created_at INTEGER NOT NULL
+            )",
+            [],
+        )?;
+
+        Ok(())
+    }
+
+    /// Inserts the input and returns true if the input was seen before, false otherwise.
+    /// This is used for replay protection to prevent probing attacks.
+    pub fn insert_input_seen_before(&self, input: OutPoint) -> Result<bool> {
+        let key = serialize(&input);
+        let was_seen_before = self.conn().execute(
+            "INSERT OR IGNORE INTO inputs_seen (outpoint, created_at) VALUES (?1, ?2)",
+            params![key, now()],
+        )? == 0;
+        Ok(was_seen_before)
+    }
+
+    /// Returns IDs of all active (incomplete) receive sessions
+    pub fn get_recv_session_ids(&self) -> Result<Vec<SessionId>> {
+        let conn = self.conn();
+        let mut stmt =
+            conn.prepare("SELECT session_id FROM receive_sessions WHERE completed_at IS NULL ORDER BY session_id DESC")?;
+
+        let session_rows = stmt.query_map([], |row| {
+            let session_id: i64 = row.get(0)?;
+            Ok(SessionId(session_id))
+        })?;
+
+        let mut session_ids = Vec::new();
+        for session_row in session_rows {
+            session_ids.push(session_row?);
+        }
+
+        Ok(session_ids)
+    }
+
+    /// Returns IDs of all active (incomplete) send sessions
+    pub fn get_send_session_ids(&self) -> Result<Vec<SessionId>> {
+        let conn = self.conn();
+        let mut stmt =
+            conn.prepare("SELECT session_id FROM send_sessions WHERE completed_at IS NULL ORDER BY session_id DESC")?;
+
+        let session_rows = stmt.query_map([], |row| {
+            let session_id: i64 = row.get(0)?;
+            Ok(SessionId(session_id))
+        })?;
+
+        let mut session_ids = Vec::new();
+        for session_row in session_rows {
+            session_ids.push(session_row?);
+        }
+
+        Ok(session_ids)
+    }
+
+    /// Returns the receiver public key for a send session
+    pub fn get_send_session_receiver_pk(&self, session_id: &SessionId) -> Result<HpkePublicKey> {
+        let conn = self.conn();
+        let mut stmt =
+            conn.prepare("SELECT receiver_pubkey FROM send_sessions WHERE session_id = ?1")?;
+        let receiver_pubkey: Vec<u8> = stmt.query_row(params![session_id], |row| row.get(0))?;
+        Ok(HpkePublicKey::from_compressed_bytes(&receiver_pubkey).expect("Valid receiver pubkey"))
+    }
+
+    /// Returns IDs and completion timestamps of all completed send sessions
+    pub fn get_inactive_send_session_ids(&self) -> Result<Vec<(SessionId, u64)>> {
+        let conn = self.conn();
+        let mut stmt = conn.prepare(
+            "SELECT session_id, completed_at FROM send_sessions WHERE completed_at IS NOT NULL",
+        )?;
+        let session_rows = stmt.query_map([], |row| {
+            let session_id: i64 = row.get(0)?;
+            let completed_at: u64 = row.get(1)?;
+            Ok((SessionId(session_id), completed_at))
+        })?;
+
+        let mut session_ids = Vec::new();
+        for session_row in session_rows {
+            session_ids.push(session_row?);
+        }
+        Ok(session_ids)
+    }
+
+    /// Returns IDs and completion timestamps of all completed receive sessions
+    pub fn get_inactive_recv_session_ids(&self) -> Result<Vec<(SessionId, u64)>> {
+        let conn = self.conn();
+        let mut stmt = conn.prepare(
+            "SELECT session_id, completed_at FROM receive_sessions WHERE completed_at IS NOT NULL",
+        )?;
+        let session_rows = stmt.query_map([], |row| {
+            let session_id: i64 = row.get(0)?;
+            let completed_at: u64 = row.get(1)?;
+            Ok((SessionId(session_id), completed_at))
+        })?;
+
+        let mut session_ids = Vec::new();
+        for session_row in session_rows {
+            session_ids.push(session_row?);
+        }
+        Ok(session_ids)
+    }
+
+    /// Formats a Unix timestamp into local date time text.
+    pub fn format_unix_timestamp(&self, timestamp: u64) -> Result<String> {
+        let Ok(timestamp) = i64::try_from(timestamp) else {
+            return Ok(format!("Invalid timestamp ({timestamp})"));
+        };
+        let conn = self.conn();
+        let dt: Option<String> = conn.query_row(
+            "SELECT datetime(?1, 'unixepoch', 'localtime')",
+            params![timestamp],
+            |row| row.get(0),
+        )?;
+        Ok(dt.unwrap_or_else(|| format!("Invalid timestamp ({timestamp})")))
+    }
+}
+
+/// Wrapper type for session IDs
+#[derive(Debug, Clone)]
+pub struct SessionId(i64);
+
+impl ToSql for SessionId {
+    fn to_sql(&self) -> bdk_wallet::rusqlite::Result<ToSqlOutput<'_>> {
+        self.0.to_sql()
+    }
+}
+
+impl fmt::Display for SessionId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl SessionId {
+    pub fn as_i64(&self) -> i64 {
+        self.0
+    }
+}
+
+/// Persister for payjoin v2 send sessions
+#[derive(Clone)]
+pub struct SenderPersister {
+    db: Arc<Database>,
+    session_id: SessionId,
+}
+
+impl SenderPersister {
+    /// Creates a new sender persister, creating a new session in the database
+    pub fn new(db: Arc<Database>, receiver_pubkey: HpkePublicKey) -> Result<Self> {
+        let session_id: i64 = db.conn().query_row(
+            "INSERT INTO send_sessions (session_id, receiver_pubkey) VALUES (NULL, ?1) RETURNING session_id",
+            params![receiver_pubkey.to_compressed_bytes()],
+            |row| row.get(0),
+        )?;
+
+        Ok(Self {
+            db,
+            session_id: SessionId(session_id),
+        })
+    }
+
+    /// Creates a persister from an existing session ID
+    pub fn from_id(db: Arc<Database>, id: SessionId) -> Self {
+        Self { db, session_id: id }
+    }
+}
+
+impl SessionPersister for SenderPersister {
+    type SessionEvent = SenderSessionEvent;
+    type InternalStorageError = Error;
+
+    fn save_event(
+        &self,
+        event: SenderSessionEvent,
+    ) -> std::result::Result<(), Self::InternalStorageError> {
+        let event_data = serde_json::to_string(&event).map_err(Error::Serialize)?;
+
+        self.db.conn().execute(
+            "INSERT INTO send_session_events (session_id, event_data, created_at) VALUES (?1, ?2, ?3)",
+            params![self.session_id, event_data, now()],
+        )?;
+
+        Ok(())
+    }
+
+    fn load(
+        &self,
+    ) -> std::result::Result<Box<dyn Iterator<Item = SenderSessionEvent>>, Self::InternalStorageError>
+    {
+        let conn = self.db.conn();
+        let mut stmt = conn.prepare(
+            "SELECT event_data FROM send_session_events WHERE session_id = ?1 ORDER BY id ASC",
+        )?;
+
+        let event_rows = stmt.query_map(params![self.session_id], |row| {
+            let event_data: String = row.get(0)?;
+            Ok(event_data)
+        })?;
+
+        let events: Vec<SenderSessionEvent> = event_rows
+            .map(|row| {
+                let event_data = row.expect("Failed to read event data from database");
+                serde_json::from_str::<SenderSessionEvent>(&event_data)
+                    .expect("Database corruption: failed to deserialize session event")
+            })
+            .collect();
+
+        Ok(Box::new(events.into_iter()))
+    }
+
+    fn close(&self) -> std::result::Result<(), Self::InternalStorageError> {
+        self.db.conn().execute(
+            "UPDATE send_sessions SET completed_at = ?1 WHERE session_id = ?2",
+            params![now(), self.session_id],
+        )?;
+
+        Ok(())
+    }
+}
+
+/// Persister for payjoin v2 receive sessions
+#[derive(Clone)]
+pub struct ReceiverPersister {
+    db: Arc<Database>,
+    session_id: SessionId,
+}
+
+impl ReceiverPersister {
+    /// Creates a new receiver persister, creating a new session in the database
+    pub fn new(db: Arc<Database>) -> Result<Self> {
+        let session_id: i64 = db.conn().query_row(
+            "INSERT INTO receive_sessions (session_id) VALUES (NULL) RETURNING session_id",
+            [],
+            |row| row.get(0),
+        )?;
+
+        Ok(Self {
+            db,
+            session_id: SessionId(session_id),
+        })
+    }
+
+    /// Creates a persister from an existing session ID
+    pub fn from_id(db: Arc<Database>, id: SessionId) -> Self {
+        Self { db, session_id: id }
+    }
+}
+
+impl SessionPersister for ReceiverPersister {
+    type SessionEvent = ReceiverSessionEvent;
+    type InternalStorageError = Error;
+
+    fn save_event(
+        &self,
+        event: ReceiverSessionEvent,
+    ) -> std::result::Result<(), Self::InternalStorageError> {
+        let event_data = serde_json::to_string(&event).map_err(Error::Serialize)?;
+
+        self.db.conn().execute(
+            "INSERT INTO receive_session_events (session_id, event_data, created_at) VALUES (?1, ?2, ?3)",
+            params![self.session_id, event_data, now()],
+        )?;
+
+        Ok(())
+    }
+
+    fn load(
+        &self,
+    ) -> std::result::Result<
+        Box<dyn Iterator<Item = ReceiverSessionEvent>>,
+        Self::InternalStorageError,
+    > {
+        let conn = self.db.conn();
+        let mut stmt = conn.prepare(
+            "SELECT event_data FROM receive_session_events WHERE session_id = ?1 ORDER BY id ASC",
+        )?;
+
+        let event_rows = stmt.query_map(params![self.session_id], |row| {
+            let event_data: String = row.get(0)?;
+            Ok(event_data)
+        })?;
+
+        let events: Vec<ReceiverSessionEvent> = event_rows
+            .map(|row| {
+                let event_data = row.expect("Failed to read event data from database");
+                serde_json::from_str::<ReceiverSessionEvent>(&event_data)
+                    .expect("Database corruption: failed to deserialize session event")
+            })
+            .collect();
+
+        Ok(Box::new(events.into_iter()))
+    }
+
+    fn close(&self) -> std::result::Result<(), Self::InternalStorageError> {
+        self.db.conn().execute(
+            "UPDATE receive_sessions SET completed_at = ?1 WHERE session_id = ?2",
+            params![now(), self.session_id],
+        )?;
+
+        Ok(())
+    }
+}