]> Untitled Git - bdk/commitdiff
feat(wallet)!: add persister (`P`) type param to `PersistedWallet<P>`
author志宇 <hello@evanlinjin.me>
Thu, 15 Aug 2024 05:49:18 +0000 (05:49 +0000)
committer志宇 <hello@evanlinjin.me>
Thu, 15 Aug 2024 05:54:34 +0000 (05:54 +0000)
This forces the caller to use the same persister type that they used for
loading/creating when calling `.persist` on `PersistedWallet`.

This is not totally fool-proof since we can have multiple instances of
the same persister type persisting to different databases. However, it
does further enforce some level of safety.

crates/wallet/src/wallet/params.rs
crates/wallet/src/wallet/persisted.rs
crates/wallet/tests/wallet.rs

index 22e7a5b73a92e337f7b38a92816d24320fc73c4b..d901247248b4cd50ae8993f06607489549f1a5c1 100644 (file)
@@ -113,7 +113,7 @@ impl CreateParams {
     pub fn create_wallet<P>(
         self,
         persister: &mut P,
-    ) -> Result<PersistedWallet, CreateWithPersistError<P::Error>>
+    ) -> Result<PersistedWallet<P>, CreateWithPersistError<P::Error>>
     where
         P: WalletPersister,
     {
@@ -124,7 +124,7 @@ impl CreateParams {
     pub async fn create_wallet_async<P>(
         self,
         persister: &mut P,
-    ) -> Result<PersistedWallet, CreateWithPersistError<P::Error>>
+    ) -> Result<PersistedWallet<P>, CreateWithPersistError<P::Error>>
     where
         P: AsyncWalletPersister,
     {
@@ -220,22 +220,22 @@ impl LoadParams {
         self
     }
 
-    /// Load [`PersistedWallet`] with the given `Db`.
+    /// Load [`PersistedWallet`] with the given `persister`.
     pub fn load_wallet<P>(
         self,
         persister: &mut P,
-    ) -> Result<Option<PersistedWallet>, LoadWithPersistError<P::Error>>
+    ) -> Result<Option<PersistedWallet<P>>, LoadWithPersistError<P::Error>>
     where
         P: WalletPersister,
     {
         PersistedWallet::load(persister, self)
     }
 
-    /// Load [`PersistedWallet`] with the given async `Db`.
+    /// Load [`PersistedWallet`] with the given async `persister`.
     pub async fn load_wallet_async<P>(
         self,
         persister: &mut P,
-    ) -> Result<Option<PersistedWallet>, LoadWithPersistError<P::Error>>
+    ) -> Result<Option<PersistedWallet<P>>, LoadWithPersistError<P::Error>>
     where
         P: AsyncWalletPersister,
     {
index cebc3fbd05efc9b973ce30cd2b070e1af235e185..38d489e27bcde04b85ae09f49019d6f8f0f00a7b 100644 (file)
@@ -1,6 +1,7 @@
 use core::{
     fmt,
     future::Future,
+    marker::PhantomData,
     ops::{Deref, DerefMut},
     pin::Pin,
 };
@@ -10,7 +11,7 @@ use chain::Merge;
 
 use crate::{descriptor::DescriptorError, ChangeSet, CreateParams, LoadParams, Wallet};
 
-/// Trait that persists [`Wallet`].
+/// Trait that persists [`PersistedWallet`].
 ///
 /// For an async version, use [`AsyncWalletPersister`].
 ///
@@ -50,7 +51,7 @@ pub trait WalletPersister {
 
 type FutureResult<'a, T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
 
-/// Async trait that persists [`Wallet`].
+/// Async trait that persists [`PersistedWallet`].
 ///
 /// For a blocking version, use [`WalletPersister`].
 ///
@@ -95,7 +96,7 @@ pub trait AsyncWalletPersister {
         Self: 'a;
 }
 
-/// Represents a persisted wallet.
+/// Represents a persisted wallet which persists into type `P`.
 ///
 /// This is a light wrapper around [`Wallet`] that enforces some level of safety-checking when used
 /// with a [`WalletPersister`] or [`AsyncWalletPersister`] implementation. Safety checks assume that
@@ -107,32 +108,36 @@ pub trait AsyncWalletPersister {
 /// * Ensure there were no previously persisted wallet data before creating a fresh wallet and
 ///     persisting it.
 /// * Only clear the staged changes of [`Wallet`] after persisting succeeds.
+/// * Ensure the wallet is persisted to the same `P` type as when created/loaded. Note that this is
+///     not completely fool-proof as you can have multiple instances of the same `P` type that are
+///     connected to different databases.
 #[derive(Debug)]
-pub struct PersistedWallet(pub(crate) Wallet);
+pub struct PersistedWallet<P> {
+    inner: Wallet,
+    marker: PhantomData<P>,
+}
 
-impl Deref for PersistedWallet {
+impl<P> Deref for PersistedWallet<P> {
     type Target = Wallet;
 
     fn deref(&self) -> &Self::Target {
-        &self.0
+        &self.inner
     }
 }
 
-impl DerefMut for PersistedWallet {
+impl<P> DerefMut for PersistedWallet<P> {
     fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.0
+        &mut self.inner
     }
 }
 
-impl PersistedWallet {
+/// Methods when `P` is a [`WalletPersister`].
+impl<P: WalletPersister> PersistedWallet<P> {
     /// Create a new [`PersistedWallet`] with the given `persister` and `params`.
-    pub fn create<P>(
+    pub fn create(
         persister: &mut P,
         params: CreateParams,
-    ) -> Result<Self, CreateWithPersistError<P::Error>>
-    where
-        P: WalletPersister,
-    {
+    ) -> Result<Self, CreateWithPersistError<P::Error>> {
         let existing = P::initialize(persister).map_err(CreateWithPersistError::Persist)?;
         if !existing.is_empty() {
             return Err(CreateWithPersistError::DataAlreadyExists(existing));
@@ -142,17 +147,50 @@ impl PersistedWallet {
         if let Some(changeset) = inner.take_staged() {
             P::persist(persister, &changeset).map_err(CreateWithPersistError::Persist)?;
         }
-        Ok(Self(inner))
+        Ok(Self {
+            inner,
+            marker: PhantomData,
+        })
+    }
+
+    /// Load a previously [`PersistedWallet`] from the given `persister` and `params`.
+    pub fn load(
+        persister: &mut P,
+        params: LoadParams,
+    ) -> Result<Option<Self>, LoadWithPersistError<P::Error>> {
+        let changeset = P::initialize(persister).map_err(LoadWithPersistError::Persist)?;
+        Wallet::load_with_params(changeset, params)
+            .map(|opt| {
+                opt.map(|inner| PersistedWallet {
+                    inner,
+                    marker: PhantomData,
+                })
+            })
+            .map_err(LoadWithPersistError::InvalidChangeSet)
+    }
+
+    /// Persist staged changes of wallet into `persister`.
+    ///
+    /// If the `persister` errors, the staged changes will not be cleared.
+    pub fn persist(&mut self, persister: &mut P) -> Result<bool, P::Error> {
+        match self.inner.staged_mut() {
+            Some(stage) => {
+                P::persist(persister, &*stage)?;
+                let _ = stage.take();
+                Ok(true)
+            }
+            None => Ok(false),
+        }
     }
+}
 
+/// Methods when `P` is an [`AsyncWalletPersister`].
+impl<P: AsyncWalletPersister> PersistedWallet<P> {
     /// Create a new [`PersistedWallet`] witht the given async `persister` and `params`.
-    pub async fn create_async<P>(
+    pub async fn create_async(
         persister: &mut P,
         params: CreateParams,
-    ) -> Result<Self, CreateWithPersistError<P::Error>>
-    where
-        P: AsyncWalletPersister,
-    {
+    ) -> Result<Self, CreateWithPersistError<P::Error>> {
         let existing = P::initialize(persister)
             .await
             .map_err(CreateWithPersistError::Persist)?;
@@ -166,64 +204,35 @@ impl PersistedWallet {
                 .await
                 .map_err(CreateWithPersistError::Persist)?;
         }
-        Ok(Self(inner))
-    }
-
-    /// Load a previously [`PersistedWallet`] from the given `persister` and `params`.
-    pub fn load<P>(
-        persister: &mut P,
-        params: LoadParams,
-    ) -> Result<Option<Self>, LoadWithPersistError<P::Error>>
-    where
-        P: WalletPersister,
-    {
-        let changeset = P::initialize(persister).map_err(LoadWithPersistError::Persist)?;
-        Wallet::load_with_params(changeset, params)
-            .map(|opt| opt.map(PersistedWallet))
-            .map_err(LoadWithPersistError::InvalidChangeSet)
+        Ok(Self {
+            inner,
+            marker: PhantomData,
+        })
     }
 
     /// Load a previously [`PersistedWallet`] from the given async `persister` and `params`.
-    pub async fn load_async<P>(
+    pub async fn load_async(
         persister: &mut P,
         params: LoadParams,
-    ) -> Result<Option<Self>, LoadWithPersistError<P::Error>>
-    where
-        P: AsyncWalletPersister,
-    {
+    ) -> Result<Option<Self>, LoadWithPersistError<P::Error>> {
         let changeset = P::initialize(persister)
             .await
             .map_err(LoadWithPersistError::Persist)?;
         Wallet::load_with_params(changeset, params)
-            .map(|opt| opt.map(PersistedWallet))
+            .map(|opt| {
+                opt.map(|inner| PersistedWallet {
+                    inner,
+                    marker: PhantomData,
+                })
+            })
             .map_err(LoadWithPersistError::InvalidChangeSet)
     }
 
-    /// Persist staged changes of wallet into `persister`.
-    ///
-    /// If the `persister` errors, the staged changes will not be cleared.
-    pub fn persist<P>(&mut self, persister: &mut P) -> Result<bool, P::Error>
-    where
-        P: WalletPersister,
-    {
-        match self.0.staged_mut() {
-            Some(stage) => {
-                P::persist(persister, &*stage)?;
-                let _ = stage.take();
-                Ok(true)
-            }
-            None => Ok(false),
-        }
-    }
-
     /// Persist staged changes of wallet into an async `persister`.
     ///
     /// If the `persister` errors, the staged changes will not be cleared.
-    pub async fn persist_async<'a, P>(&'a mut self, persister: &mut P) -> Result<bool, P::Error>
-    where
-        P: AsyncWalletPersister,
-    {
-        match self.0.staged_mut() {
+    pub async fn persist_async<'a>(&'a mut self, persister: &mut P) -> Result<bool, P::Error> {
+        match self.inner.staged_mut() {
             Some(stage) => {
                 P::persist(persister, &*stage).await?;
                 let _ = stage.take();
index 53edd821584a2da56321eaa7934a05cb03abdc03..c530e779ceb1c2675970a21a573f7c7a851481ef 100644 (file)
@@ -194,7 +194,7 @@ fn wallet_load_checks() -> anyhow::Result<()> {
     where
         CreateDb: Fn(&Path) -> anyhow::Result<Db>,
         OpenDb: Fn(&Path) -> anyhow::Result<Db>,
-        Db: WalletPersister,
+        Db: WalletPersister + std::fmt::Debug,
         Db::Error: std::error::Error + Send + Sync + 'static,
     {
         let temp_dir = tempfile::tempdir().expect("must create tempdir");