]> Untitled Git - bdk/commitdiff
feat!: change `load_from_persistence` to return an option
author志宇 <hello@evanlinjin.me>
Wed, 1 Nov 2023 01:21:24 +0000 (09:21 +0800)
committer志宇 <hello@evanlinjin.me>
Wed, 15 Nov 2023 23:07:49 +0000 (07:07 +0800)
`PersistBackend::is_empty` is removed. Instead, `load_from_persistence`
returns an option of the changeset. `None` means persistence is empty.
This is a better API than a separate method. We can now differentiate
between a persisted single changeset and nothing persisted.

`Store::aggregate_changeset` is refactored to return a `Result` instead
of a tuple. A new error type (`AggregateChangesetsError`) is introduced
to include the partially-aggregated changeset in the error. This is a
more idiomatic API.

crates/bdk/src/wallet/mod.rs
crates/chain/src/persist.rs
crates/file_store/src/store.rs
example-crates/example_cli/src/lib.rs

index ea76ad65688f1430af7f0a85f11c1fe2158bd33a..c9a1c28cb29efd6680bacd436d1b850823888d39 100644 (file)
@@ -293,6 +293,8 @@ pub enum LoadError<L> {
     Descriptor(crate::descriptor::DescriptorError),
     /// Loading data from the persistence backend failed.
     Load(L),
+    /// Wallet not initialized, persistence backend is empty.
+    NotInitialized,
     /// Data loaded from persistence is missing network type.
     MissingNetwork,
     /// Data loaded from persistence is missing genesis hash.
@@ -307,6 +309,9 @@ where
         match self {
             LoadError::Descriptor(e) => e.fmt(f),
             LoadError::Load(e) => e.fmt(f),
+            LoadError::NotInitialized => {
+                write!(f, "wallet is not initialized, persistence backend is empty")
+            }
             LoadError::MissingNetwork => write!(f, "loaded data is missing network type"),
             LoadError::MissingGenesis => write!(f, "loaded data is missing genesis hash"),
         }
@@ -330,6 +335,8 @@ pub enum NewOrLoadError<W, L> {
     Write(W),
     /// Loading from the persistence backend failed.
     Load(L),
+    /// Wallet is not initialized, persistence backend is empty.
+    NotInitialized,
     /// The loaded genesis hash does not match what was provided.
     LoadedGenesisDoesNotMatch {
         /// The expected genesis block hash.
@@ -356,6 +363,9 @@ where
             NewOrLoadError::Descriptor(e) => e.fmt(f),
             NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e),
             NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e),
+            NewOrLoadError::NotInitialized => {
+                write!(f, "wallet is not initialized, persistence backend is empty")
+            }
             NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => {
                 write!(f, "loaded genesis hash is not {}, got {:?}", expected, got)
             }
@@ -451,11 +461,26 @@ impl<D> Wallet<D> {
         change_descriptor: Option<E>,
         mut db: D,
     ) -> Result<Self, LoadError<D::LoadError>>
+    where
+        D: PersistBackend<ChangeSet>,
+    {
+        let changeset = db
+            .load_from_persistence()
+            .map_err(LoadError::Load)?
+            .ok_or(LoadError::NotInitialized)?;
+        Self::load_from_changeset(descriptor, change_descriptor, db, changeset)
+    }
+
+    fn load_from_changeset<E: IntoWalletDescriptor>(
+        descriptor: E,
+        change_descriptor: Option<E>,
+        db: D,
+        changeset: ChangeSet,
+    ) -> Result<Self, LoadError<D::LoadError>>
     where
         D: PersistBackend<ChangeSet>,
     {
         let secp = Secp256k1::new();
-        let changeset = db.load_from_persistence().map_err(LoadError::Load)?;
         let network = changeset.network.ok_or(LoadError::MissingNetwork)?;
         let chain =
             LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?;
@@ -517,8 +542,43 @@ impl<D> Wallet<D> {
     where
         D: PersistBackend<ChangeSet>,
     {
-        if db.is_empty().map_err(NewOrLoadError::Load)? {
-            return Self::new_with_genesis_hash(
+        let changeset = db.load_from_persistence().map_err(NewOrLoadError::Load)?;
+        match changeset {
+            Some(changeset) => {
+                let wallet =
+                    Self::load_from_changeset(descriptor, change_descriptor, db, changeset)
+                        .map_err(|e| match e {
+                            LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
+                            LoadError::Load(e) => NewOrLoadError::Load(e),
+                            LoadError::NotInitialized => NewOrLoadError::NotInitialized,
+                            LoadError::MissingNetwork => {
+                                NewOrLoadError::LoadedNetworkDoesNotMatch {
+                                    expected: network,
+                                    got: None,
+                                }
+                            }
+                            LoadError::MissingGenesis => {
+                                NewOrLoadError::LoadedGenesisDoesNotMatch {
+                                    expected: genesis_hash,
+                                    got: None,
+                                }
+                            }
+                        })?;
+                if wallet.network != network {
+                    return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
+                        expected: network,
+                        got: Some(wallet.network),
+                    });
+                }
+                if wallet.chain.genesis_hash() != genesis_hash {
+                    return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
+                        expected: genesis_hash,
+                        got: Some(wallet.chain.genesis_hash()),
+                    });
+                }
+                Ok(wallet)
+            }
+            None => Self::new_with_genesis_hash(
                 descriptor,
                 change_descriptor,
                 db,
@@ -528,34 +588,8 @@ impl<D> Wallet<D> {
             .map_err(|e| match e {
                 NewError::Descriptor(e) => NewOrLoadError::Descriptor(e),
                 NewError::Write(e) => NewOrLoadError::Write(e),
-            });
-        }
-
-        let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e {
-            LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
-            LoadError::Load(e) => NewOrLoadError::Load(e),
-            LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch {
-                expected: network,
-                got: None,
-            },
-            LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch {
-                expected: genesis_hash,
-                got: None,
-            },
-        })?;
-        if wallet.network != network {
-            return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
-                expected: network,
-                got: Some(wallet.network),
-            });
-        }
-        if wallet.chain.genesis_hash() != genesis_hash {
-            return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
-                expected: genesis_hash,
-                got: Some(wallet.chain.genesis_hash()),
-            });
+            }),
         }
-        Ok(wallet)
     }
 
     /// Get the Bitcoin network the wallet is using.
index 634e369e9a55932e8cfde428feafe3c7f4cb8aac..3c8c8b9e12b2646ca3c2a098de3fafd203f87553 100644 (file)
@@ -79,19 +79,10 @@ pub trait PersistBackend<C> {
     fn write_changes(&mut self, changeset: &C) -> Result<(), Self::WriteError>;
 
     /// Return the aggregate changeset `C` from persistence.
-    fn load_from_persistence(&mut self) -> Result<C, Self::LoadError>;
-
-    /// Returns whether the persistence backend contains no data.
-    fn is_empty(&mut self) -> Result<bool, Self::LoadError>
-    where
-        C: Append,
-    {
-        self.load_from_persistence()
-            .map(|changeset| changeset.is_empty())
-    }
+    fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError>;
 }
 
-impl<C: Default> PersistBackend<C> for () {
+impl<C> PersistBackend<C> for () {
     type WriteError = Infallible;
 
     type LoadError = Infallible;
@@ -100,11 +91,7 @@ impl<C: Default> PersistBackend<C> for () {
         Ok(())
     }
 
-    fn load_from_persistence(&mut self) -> Result<C, Self::LoadError> {
-        Ok(C::default())
-    }
-
-    fn is_empty(&mut self) -> Result<bool, Self::LoadError> {
-        Ok(true)
+    fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError> {
+        Ok(None)
     }
 }
index 8af10cbd0e961190ba21e19f943caa17d6ae2590..bf88b8d3448421db886598b8265e53ec5602edf7 100644 (file)
@@ -23,7 +23,7 @@ pub struct Store<'a, C> {
 
 impl<'a, C> PersistBackend<C> for Store<'a, C>
 where
-    C: Default + Append + serde::Serialize + serde::de::DeserializeOwned,
+    C: Append + serde::Serialize + serde::de::DeserializeOwned,
 {
     type WriteError = std::io::Error;
 
@@ -33,23 +33,14 @@ where
         self.append_changeset(changeset)
     }
 
-    fn load_from_persistence(&mut self) -> Result<C, Self::LoadError> {
-        let (changeset, result) = self.aggregate_changesets();
-        result.map(|_| changeset)
-    }
-
-    fn is_empty(&mut self) -> Result<bool, Self::LoadError> {
-        let init_pos = self.db_file.stream_position()?;
-        let stream_len = self.db_file.seek(io::SeekFrom::End(0))?;
-        let magic_len = self.magic.len() as u64;
-        self.db_file.seek(io::SeekFrom::Start(init_pos))?;
-        Ok(stream_len == magic_len)
+    fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError> {
+        self.aggregate_changesets().map_err(|e| e.iter_error)
     }
 }
 
 impl<'a, C> Store<'a, C>
 where
-    C: Default + Append + serde::Serialize + serde::de::DeserializeOwned,
+    C: Append + serde::Serialize + serde::de::DeserializeOwned,
 {
     /// Create a new [`Store`] file in write-only mode; error if the file exists.
     ///
@@ -156,16 +147,24 @@ where
     ///
     /// **WARNING**: This method changes the write position of the underlying file. The next
     /// changeset will be written over the erroring entry (or the end of the file if none existed).
-    pub fn aggregate_changesets(&mut self) -> (C, Result<(), IterError>) {
-        let mut changeset = C::default();
-        let result = (|| {
-            for next_changeset in self.iter_changesets() {
-                changeset.append(next_changeset?);
+    pub fn aggregate_changesets(&mut self) -> Result<Option<C>, AggregateChangesetsError<C>> {
+        let mut changeset = Option::<C>::None;
+        for next_changeset in self.iter_changesets() {
+            let next_changeset = match next_changeset {
+                Ok(next_changeset) => next_changeset,
+                Err(iter_error) => {
+                    return Err(AggregateChangesetsError {
+                        changeset,
+                        iter_error,
+                    })
+                }
+            };
+            match &mut changeset {
+                Some(changeset) => changeset.append(next_changeset),
+                changeset => *changeset = Some(next_changeset),
             }
-            Ok(())
-        })();
-
-        (changeset, result)
+        }
+        Ok(changeset)
     }
 
     /// Append a new changeset to the file and truncate the file to the end of the appended
@@ -196,6 +195,24 @@ where
     }
 }
 
+/// Error type for [`Store::aggregate_changesets`].
+#[derive(Debug)]
+pub struct AggregateChangesetsError<C> {
+    /// The partially-aggregated changeset.
+    pub changeset: Option<C>,
+
+    /// The error returned by [`EntryIter`].
+    pub iter_error: IterError,
+}
+
+impl<C> std::fmt::Display for AggregateChangesetsError<C> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.iter_error, f)
+    }
+}
+
+impl<C: std::fmt::Debug> std::error::Error for AggregateChangesetsError<C> {}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -248,25 +265,11 @@ mod test {
         {
             let mut db = Store::<TestChangeSet>::open_or_create_new(&TEST_MAGIC_BYTES, &file_path)
                 .expect("must recover");
-            let (recovered_changeset, r) = db.aggregate_changesets();
-            r.expect("must succeed");
-            assert_eq!(recovered_changeset, changeset);
+            let recovered_changeset = db.aggregate_changesets().expect("must succeed");
+            assert_eq!(recovered_changeset, Some(changeset));
         }
     }
 
-    #[test]
-    fn is_empty() {
-        let mut file = NamedTempFile::new().unwrap();
-        file.write_all(&TEST_MAGIC_BYTES).expect("should write");
-
-        let mut db =
-            Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, file.path()).expect("must open");
-        assert!(db.is_empty().expect("must read"));
-        db.write_changes(&vec!["hello".to_string(), "world".to_string()])
-            .expect("must write");
-        assert!(!db.is_empty().expect("must read"));
-    }
-
     #[test]
     fn new_fails_if_file_is_too_short() {
         let mut file = NamedTempFile::new().unwrap();
index 0b5d9cd37c841ccca38149701480cdec2b0d64fd..f9574c0e0b60f9be973595ea8a72e32116233e0a 100644 (file)
@@ -687,7 +687,7 @@ where
         Err(err) => return Err(anyhow::anyhow!("failed to init db backend: {:?}", err)),
     };
 
-    let init_changeset = db_backend.load_from_persistence()?;
+    let init_changeset = db_backend.load_from_persistence()?.unwrap_or_default();
 
     Ok((
         args,