]> Untitled Git - bdk/commitdiff
feat(chain,wallet)!: Change persist-traits to be "safer"
author志宇 <hello@evanlinjin.me>
Fri, 19 Jul 2024 07:05:38 +0000 (07:05 +0000)
committer志宇 <hello@evanlinjin.me>
Fri, 19 Jul 2024 07:05:38 +0000 (07:05 +0000)
Previously, `Persist{Async}With::persist` can be directly called as a
method on the type (i.e. `Wallet`). However, the `db: Db` (which is an
input) may not be initialized. We want a design which makes it harder
for the caller to make this mistake.

We change `Persist{Async}With::persist` to be an "associated function"
which takes two inputs: `db: &mut Db` and `changeset`. However, the
implementer cannot take directly from `Self` (as it's no longer an
input). So we introduce a new trait `Staged` which defines the staged
changeset type and a method that gives us a `&mut` of the staged
changes.

crates/chain/src/persist.rs
crates/wallet/src/wallet/mod.rs
crates/wallet/src/wallet/persisted.rs

index 5f7b37e4c8a34d04a5d829dd34f1ff6abdca2fce..6bcdb6bdfd9e3a76bf7fd3c40c8085cf1767e63c 100644 (file)
@@ -6,10 +6,21 @@ use core::{
 
 use alloc::boxed::Box;
 
+use crate::Merge;
+
+/// Represents a type that contains staged changes.
+pub trait Staged {
+    /// Type for staged changes.
+    type ChangeSet: Merge;
+
+    /// Get mutable reference of staged changes.
+    fn staged(&mut self) -> &mut Self::ChangeSet;
+}
+
 /// Trait that persists the type with `Db`.
 ///
 /// Methods of this trait should not be called directly.
-pub trait PersistWith<Db>: Sized {
+pub trait PersistWith<Db>: Staged + Sized {
     /// Parameters for [`PersistWith::create`].
     type CreateParams;
     /// Parameters for [`PersistWith::load`].
@@ -21,20 +32,23 @@ pub trait PersistWith<Db>: Sized {
     /// Error type of [`PersistWith::persist`].
     type PersistError;
 
-    /// Create the type and initialize the `Db`.
+    /// Initialize the `Db` and create `Self`.
     fn create(db: &mut Db, params: Self::CreateParams) -> Result<Self, Self::CreateError>;
 
-    /// Load the type from the `Db`.
+    /// Initialize the `Db` and load a previously-persisted `Self`.
     fn load(db: &mut Db, params: Self::LoadParams) -> Result<Option<Self>, Self::LoadError>;
 
-    /// Persist staged changes into `Db`.
-    fn persist(&mut self, db: &mut Db) -> Result<bool, Self::PersistError>;
+    /// Persist changes to the `Db`.
+    fn persist(
+        db: &mut Db,
+        changeset: &<Self as Staged>::ChangeSet,
+    ) -> Result<(), Self::PersistError>;
 }
 
 type FutureResult<'a, T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
 
 /// Trait that persists the type with an async `Db`.
-pub trait PersistAsyncWith<Db>: Sized {
+pub trait PersistAsyncWith<Db>: Staged + Sized {
     /// Parameters for [`PersistAsyncWith::create`].
     type CreateParams;
     /// Parameters for [`PersistAsyncWith::load`].
@@ -46,14 +60,17 @@ pub trait PersistAsyncWith<Db>: Sized {
     /// Error type of [`PersistAsyncWith::persist`].
     type PersistError;
 
-    /// Create the type and initialize the `Db`.
+    /// Initialize the `Db` and create `Self`.
     fn create(db: &mut Db, params: Self::CreateParams) -> FutureResult<Self, Self::CreateError>;
 
-    /// Load the type from `Db`.
+    /// Initialize the `Db` and load a previously-persisted `Self`.
     fn load(db: &mut Db, params: Self::LoadParams) -> FutureResult<Option<Self>, Self::LoadError>;
 
-    /// Persist staged changes into `Db`.
-    fn persist<'a>(&'a mut self, db: &'a mut Db) -> FutureResult<'a, bool, Self::PersistError>;
+    /// Persist changes to the `Db`.
+    fn persist<'a>(
+        db: &'a mut Db,
+        changeset: &'a <Self as Staged>::ChangeSet,
+    ) -> FutureResult<'a, (), Self::PersistError>;
 }
 
 /// Represents a persisted `T`.
@@ -102,14 +119,24 @@ impl<T> Persisted<T> {
     }
 
     /// Persist staged changes of `T` into `Db`.
+    ///
+    /// If the database errors, the staged changes will not be cleared.
     pub fn persist<Db>(&mut self, db: &mut Db) -> Result<bool, T::PersistError>
     where
         T: PersistWith<Db>,
     {
-        self.inner.persist(db)
+        let stage = T::staged(&mut self.inner);
+        if stage.is_empty() {
+            return Ok(false);
+        }
+        T::persist(db, &*stage)?;
+        stage.take();
+        Ok(true)
     }
 
     /// Persist staged changes of `T` into an async `Db`.
+    ///
+    /// If the database errors, the staged changes will not be cleared.
     pub async fn persist_async<'a, Db>(
         &'a mut self,
         db: &'a mut Db,
@@ -117,7 +144,13 @@ impl<T> Persisted<T> {
     where
         T: PersistAsyncWith<Db>,
     {
-        self.inner.persist(db).await
+        let stage = T::staged(&mut self.inner);
+        if stage.is_empty() {
+            return Ok(false);
+        }
+        T::persist(db, &*stage).await?;
+        stage.take();
+        Ok(true)
     }
 }
 
index d2d7c0ad04746a759aaa0778e642dc6791c668b0..84fe52896b30d2242b92c43ef3d8917721a06932 100644 (file)
@@ -42,6 +42,7 @@ use bitcoin::{
 use bitcoin::{consensus::encode::serialize, transaction, BlockHash, Psbt};
 use bitcoin::{constants::genesis_block, Amount};
 use bitcoin::{secp256k1::Secp256k1, Weight};
+use chain::Staged;
 use core::fmt;
 use core::mem;
 use core::ops::Deref;
@@ -119,6 +120,14 @@ pub struct Wallet {
     secp: SecpCtx,
 }
 
+impl Staged for Wallet {
+    type ChangeSet = ChangeSet;
+
+    fn staged(&mut self) -> &mut Self::ChangeSet {
+        &mut self.stage
+    }
+}
+
 /// An update to [`Wallet`].
 ///
 /// It updates [`KeychainTxOutIndex`], [`bdk_chain::TxGraph`] and [`local_chain::LocalChain`] atomically.
index 70d9f27cceae49f250c1ad1a31a592e8032f3124..5dfea3f2a3df3cd8cfcd54e939ebf8500cc7cac9 100644 (file)
@@ -41,14 +41,10 @@ impl<'c> chain::PersistWith<bdk_chain::sqlite::Transaction<'c>> for Wallet {
     }
 
     fn persist(
-        &mut self,
-        conn: &mut bdk_chain::sqlite::Transaction,
-    ) -> Result<bool, Self::PersistError> {
-        if let Some(changeset) = self.take_staged() {
-            changeset.persist_to_sqlite(conn)?;
-            return Ok(true);
-        }
-        Ok(false)
+        db: &mut bdk_chain::sqlite::Transaction<'c>,
+        changeset: &<Self as chain::Staged>::ChangeSet,
+    ) -> Result<(), Self::PersistError> {
+        changeset.persist_to_sqlite(db)
     }
 }
 
@@ -82,13 +78,12 @@ impl chain::PersistWith<bdk_chain::sqlite::Connection> for Wallet {
     }
 
     fn persist(
-        &mut self,
         db: &mut bdk_chain::sqlite::Connection,
-    ) -> Result<bool, Self::PersistError> {
-        let mut db_tx = db.transaction()?;
-        let has_changes = chain::PersistWith::persist(self, &mut db_tx)?;
-        db_tx.commit()?;
-        Ok(has_changes)
+        changeset: &<Self as chain::Staged>::ChangeSet,
+    ) -> Result<(), Self::PersistError> {
+        let db_tx = db.transaction()?;
+        changeset.persist_to_sqlite(&db_tx)?;
+        db_tx.commit()
     }
 }
 
@@ -126,14 +121,10 @@ impl chain::PersistWith<bdk_file_store::Store<crate::ChangeSet>> for Wallet {
     }
 
     fn persist(
-        &mut self,
         db: &mut bdk_file_store::Store<crate::ChangeSet>,
-    ) -> Result<bool, Self::PersistError> {
-        if let Some(changeset) = self.take_staged() {
-            db.append_changeset(&changeset)?;
-            return Ok(true);
-        }
-        Ok(false)
+        changeset: &<Self as chain::Staged>::ChangeSet,
+    ) -> Result<(), Self::PersistError> {
+        db.append_changeset(changeset)
     }
 }