]> Untitled Git - bdk/commitdiff
feat(wallet)!: add `new_or_load` methods
author志宇 <hello@evanlinjin.me>
Fri, 27 Oct 2023 06:14:25 +0000 (14:14 +0800)
committer志宇 <hello@evanlinjin.me>
Wed, 15 Nov 2023 23:07:48 +0000 (07:07 +0800)
These methods try to load wallet from persistence and initializes the
wallet instead if non-existant.

An internal helper method `create_signers` is added to reuse code.
Documentation is also improved.

crates/bdk/src/wallet/mod.rs
example-crates/wallet_esplora_async/src/main.rs
example-crates/wallet_esplora_blocking/src/main.rs

index 05a8342664140254004f6dbb255ee0db4e6b1344..ea76ad65688f1430af7f0a85f11c1fe2158bd33a 100644 (file)
@@ -28,7 +28,7 @@ use bdk_chain::{
     Append, BlockId, ChainPosition, ConfirmationTime, ConfirmationTimeHeightAnchor, FullTxOut,
     IndexedTxGraph, Persist, PersistBackend,
 };
-use bitcoin::secp256k1::Secp256k1;
+use bitcoin::secp256k1::{All, Secp256k1};
 use bitcoin::sighash::{EcdsaSighashType, TapSighashType};
 use bitcoin::{
     absolute, Address, Network, OutPoint, Script, ScriptBuf, Sequence, Transaction, TxOut, Txid,
@@ -253,8 +253,13 @@ impl Wallet {
     }
 }
 
+/// The error type when constructing a fresh [`Wallet`].
+///
+/// Methods [`new`] and [`new_with_genesis_hash`] may return this error.
+///
+/// [`new`]: Wallet::new
+/// [`new_with_genesis_hash`]: Wallet::new_with_genesis_hash
 #[derive(Debug)]
-/// Error returned from [`Wallet::new`]
 pub enum NewError<W> {
     /// There was problem with the passed-in descriptor(s).
     Descriptor(crate::descriptor::DescriptorError),
@@ -277,7 +282,11 @@ where
 #[cfg(feature = "std")]
 impl<W> std::error::Error for NewError<W> where W: core::fmt::Display + core::fmt::Debug {}
 
-/// An error that may occur when loading a [`Wallet`] from persistence.
+/// The error type when loading a [`Wallet`] from persistence.
+///
+/// Method [`load`] may return this error.
+///
+/// [`load`]: Wallet::load
 #[derive(Debug)]
 pub enum LoadError<L> {
     /// There was a problem with the passed-in descriptor(s).
@@ -307,6 +316,64 @@ where
 #[cfg(feature = "std")]
 impl<L> std::error::Error for LoadError<L> where L: core::fmt::Display + core::fmt::Debug {}
 
+/// Error type for when we try load a [`Wallet`] from persistence and creating it if non-existant.
+///
+/// Methods [`new_or_load`] and [`new_or_load_with_genesis_hash`] may return this error.
+///
+/// [`new_or_load`]: Wallet::new_or_load
+/// [`new_or_load_with_genesis_hash`]: Wallet::new_or_load_with_genesis_hash
+#[derive(Debug)]
+pub enum NewOrLoadError<W, L> {
+    /// There is a problem with the passed-in descriptor.
+    Descriptor(crate::descriptor::DescriptorError),
+    /// Writing to the persistence backend failed.
+    Write(W),
+    /// Loading from the persistence backend failed.
+    Load(L),
+    /// The loaded genesis hash does not match what was provided.
+    LoadedGenesisDoesNotMatch {
+        /// The expected genesis block hash.
+        expected: BlockHash,
+        /// The block hash loaded from persistence.
+        got: Option<BlockHash>,
+    },
+    /// The loaded network type does not match what was provided.
+    LoadedNetworkDoesNotMatch {
+        /// The expected network type.
+        expected: Network,
+        /// The network type loaded from persistence.
+        got: Option<Network>,
+    },
+}
+
+impl<W, L> fmt::Display for NewOrLoadError<W, L>
+where
+    W: fmt::Display,
+    L: fmt::Display,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            NewOrLoadError::Descriptor(e) => e.fmt(f),
+            NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e),
+            NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e),
+            NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => {
+                write!(f, "loaded genesis hash is not {}, got {:?}", expected, got)
+            }
+            NewOrLoadError::LoadedNetworkDoesNotMatch { expected, got } => {
+                write!(f, "loaded network type is not {}, got {:?}", expected, got)
+            }
+        }
+    }
+}
+
+#[cfg(feature = "std")]
+impl<W, L> std::error::Error for NewOrLoadError<W, L>
+where
+    W: core::fmt::Display + core::fmt::Debug,
+    L: core::fmt::Display + core::fmt::Debug,
+{
+}
+
 /// An error that may occur when inserting a transaction into [`Wallet`].
 #[derive(Debug)]
 pub enum InsertTxError {
@@ -321,8 +388,7 @@ pub enum InsertTxError {
 }
 
 impl<D> Wallet<D> {
-    /// Create a wallet from a `descriptor` (and an optional `change_descriptor`) and load related
-    /// transaction data from `db`.
+    /// Initialize an empty [`Wallet`].
     pub fn new<E: IntoWalletDescriptor>(
         descriptor: E,
         change_descriptor: Option<E>,
@@ -336,9 +402,10 @@ impl<D> Wallet<D> {
         Self::new_with_genesis_hash(descriptor, change_descriptor, db, network, genesis_hash)
     }
 
-    /// Create a new [`Wallet`] with a custom genesis hash.
+    /// Initialize an empty [`Wallet`] with a custom genesis hash.
     ///
-    /// This is like [`Wallet::new`] with an additional `custom_genesis_hash` parameter.
+    /// This is like [`Wallet::new`] with an additional `genesis_hash` parameter. This is useful
+    /// for syncing from alternative networks.
     pub fn new_with_genesis_hash<E: IntoWalletDescriptor>(
         descriptor: E,
         change_descriptor: Option<E>,
@@ -350,35 +417,18 @@ impl<D> Wallet<D> {
         D: PersistBackend<ChangeSet>,
     {
         let secp = Secp256k1::new();
-        let (chain, _) = LocalChain::from_genesis_hash(genesis_hash);
-        let mut indexed_graph = IndexedTxGraph::<
-            ConfirmationTimeHeightAnchor,
-            KeychainTxOutIndex<KeychainKind>,
-        >::default();
+        let (chain, chain_changeset) = LocalChain::from_genesis_hash(genesis_hash);
+        let mut index = KeychainTxOutIndex::<KeychainKind>::default();
 
-        let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network)
-            .map_err(NewError::Descriptor)?;
-        indexed_graph
-            .index
-            .add_keychain(KeychainKind::External, descriptor.clone());
-        let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp));
-
-        let change_signers = Arc::new(match change_descriptor {
-            Some(desc) => {
-                let (descriptor, keymap) = into_wallet_descriptor_checked(desc, &secp, network)
-                    .map_err(NewError::Descriptor)?;
-                let signers = SignersContainer::build(keymap, &descriptor, &secp);
-                indexed_graph
-                    .index
-                    .add_keychain(KeychainKind::Internal, descriptor);
-                signers
-            }
-            None => SignersContainer::new(),
-        });
+        let (signers, change_signers) =
+            create_signers(&mut index, &secp, descriptor, change_descriptor, network)
+                .map_err(NewError::Descriptor)?;
+
+        let indexed_graph = IndexedTxGraph::new(index);
 
         let mut persist = Persist::new(db);
         persist.stage(ChangeSet {
-            chain: chain.initial_changeset(),
+            chain: chain_changeset,
             indexed_tx_graph: indexed_graph.initial_changeset(),
             network: Some(network),
         });
@@ -395,7 +445,7 @@ impl<D> Wallet<D> {
         })
     }
 
-    /// Load [`Wallet`] from persistence.
+    /// Load [`Wallet`] from the given persistence backend.
     pub fn load<E: IntoWalletDescriptor>(
         descriptor: E,
         change_descriptor: Option<E>,
@@ -405,31 +455,15 @@ impl<D> Wallet<D> {
         D: PersistBackend<ChangeSet>,
     {
         let secp = Secp256k1::new();
-
         let changeset = db.load_from_persistence().map_err(LoadError::Load)?;
         let network = changeset.network.ok_or(LoadError::MissingNetwork)?;
-
         let chain =
             LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?;
-
         let mut index = KeychainTxOutIndex::<KeychainKind>::default();
 
-        let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network)
-            .map_err(LoadError::Descriptor)?;
-        let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp));
-        index.add_keychain(KeychainKind::External, descriptor);
-
-        let change_signers = Arc::new(match change_descriptor {
-            Some(descriptor) => {
-                let (descriptor, keymap) =
-                    into_wallet_descriptor_checked(descriptor, &secp, network)
-                        .map_err(LoadError::Descriptor)?;
-                let signers = SignersContainer::build(keymap, &descriptor, &secp);
-                index.add_keychain(KeychainKind::Internal, descriptor);
-                signers
-            }
-            None => SignersContainer::new(),
-        });
+        let (signers, change_signers) =
+            create_signers(&mut index, &secp, descriptor, change_descriptor, network)
+                .map_err(LoadError::Descriptor)?;
 
         let indexed_graph = IndexedTxGraph::new(index);
         let persist = Persist::new(db);
@@ -445,6 +479,85 @@ impl<D> Wallet<D> {
         })
     }
 
+    /// Either loads [`Wallet`] from persistence, or initializes it if it does not exist.
+    ///
+    /// This method will fail if the loaded [`Wallet`] has different parameters to those provided.
+    pub fn new_or_load<E: IntoWalletDescriptor>(
+        descriptor: E,
+        change_descriptor: Option<E>,
+        db: D,
+        network: Network,
+    ) -> Result<Self, NewOrLoadError<D::WriteError, D::LoadError>>
+    where
+        D: PersistBackend<ChangeSet>,
+    {
+        let genesis_hash = genesis_block(network).block_hash();
+        Self::new_or_load_with_genesis_hash(
+            descriptor,
+            change_descriptor,
+            db,
+            network,
+            genesis_hash,
+        )
+    }
+
+    /// Either loads [`Wallet`] from persistence, or initializes it if it does not exist (with a
+    /// custom genesis hash).
+    ///
+    /// This method will fail if the loaded [`Wallet`] has different parameters to those provided.
+    /// This is like [`Wallet::new_or_load`] with an additional `genesis_hash` parameter. This is
+    /// useful for syncing from alternative networks.
+    pub fn new_or_load_with_genesis_hash<E: IntoWalletDescriptor>(
+        descriptor: E,
+        change_descriptor: Option<E>,
+        mut db: D,
+        network: Network,
+        genesis_hash: BlockHash,
+    ) -> Result<Self, NewOrLoadError<D::WriteError, D::LoadError>>
+    where
+        D: PersistBackend<ChangeSet>,
+    {
+        if db.is_empty().map_err(NewOrLoadError::Load)? {
+            return Self::new_with_genesis_hash(
+                descriptor,
+                change_descriptor,
+                db,
+                network,
+                genesis_hash,
+            )
+            .map_err(|e| match e {
+                NewError::Descriptor(e) => NewOrLoadError::Descriptor(e),
+                NewError::Write(e) => NewOrLoadError::Write(e),
+            });
+        }
+
+        let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e {
+            LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
+            LoadError::Load(e) => NewOrLoadError::Load(e),
+            LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch {
+                expected: network,
+                got: None,
+            },
+            LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch {
+                expected: genesis_hash,
+                got: None,
+            },
+        })?;
+        if wallet.network != network {
+            return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
+                expected: network,
+                got: Some(wallet.network),
+            });
+        }
+        if wallet.chain.genesis_hash() != genesis_hash {
+            return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
+                expected: genesis_hash,
+                got: Some(wallet.chain.genesis_hash()),
+            });
+        }
+        Ok(wallet)
+    }
+
     /// Get the Bitcoin network the wallet is using.
     pub fn network(&self) -> Network {
         self.network
@@ -2158,6 +2271,30 @@ fn new_local_utxo(
     }
 }
 
+fn create_signers<E: IntoWalletDescriptor>(
+    index: &mut KeychainTxOutIndex<KeychainKind>,
+    secp: &Secp256k1<All>,
+    descriptor: E,
+    change_descriptor: Option<E>,
+    network: Network,
+) -> Result<(Arc<SignersContainer>, Arc<SignersContainer>), crate::descriptor::error::Error> {
+    let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?;
+    let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp));
+    index.add_keychain(KeychainKind::External, descriptor);
+
+    let change_signers = match change_descriptor {
+        Some(descriptor) => {
+            let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?;
+            let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp));
+            index.add_keychain(KeychainKind::Internal, descriptor);
+            signers
+        }
+        None => Arc::new(SignersContainer::new()),
+    };
+
+    Ok((signers, change_signers))
+}
+
 #[macro_export]
 #[doc(hidden)]
 /// Macro for getting a wallet for use in a doctest
index 5c7d09d59ac160b6c10da04e92c7a46df3a2f7d9..e44db9d8d65eed1950b706eb3b84bd6964695135 100644 (file)
@@ -2,6 +2,7 @@ use std::{io::Write, str::FromStr};
 
 use bdk::{
     bitcoin::{Address, Network},
+    chain::PersistBackend,
     wallet::{AddressIndex, Update},
     SignOptions, Wallet,
 };
index f4de498cdb86335edaa53de837ac3d784f889196..ec1076597aa1a7fcf686224ba58a70dafd573b32 100644 (file)
@@ -7,6 +7,7 @@ use std::{io::Write, str::FromStr};
 
 use bdk::{
     bitcoin::{Address, Network},
+    chain::PersistBackend,
     wallet::{AddressIndex, Update},
     SignOptions, Wallet,
 };