]> Untitled Git - bdk/commitdiff
[wallet] Refill the address pool whenever necessary
authorAlekos Filini <alekos.filini@gmail.com>
Thu, 6 Aug 2020 16:11:07 +0000 (18:11 +0200)
committerAlekos Filini <alekos.filini@gmail.com>
Thu, 6 Aug 2020 16:11:07 +0000 (18:11 +0200)
src/cli.rs
src/wallet/mod.rs

index 735c0af92a0ce53486c4fe4041b8ac9f757917da..d4e4578905700cbdc8da1e5e6140ee9ed7d3a3a9 100644 (file)
@@ -291,7 +291,7 @@ where
     if let Some(_sub_matches) = matches.subcommand_matches("get_new_address") {
         Ok(Some(format!("{}", wallet.get_new_address()?)))
     } else if let Some(_sub_matches) = matches.subcommand_matches("sync") {
-        maybe_await!(wallet.sync(None, None))?;
+        maybe_await!(wallet.sync(None))?;
         Ok(None)
     } else if let Some(_sub_matches) = matches.subcommand_matches("list_unspent") {
         let mut res = String::new();
index f59424d1261ad6d3f1748ba5b6a50715de186a3f..7948a672aa7c78c07a02d392ca89133a524e52bb 100644 (file)
@@ -22,8 +22,7 @@ pub mod tx_builder;
 pub mod utils;
 
 pub use tx_builder::TxBuilder;
-
-use self::utils::IsDust;
+use utils::IsDust;
 
 use crate::blockchain::{noop_progress, Blockchain, OfflineBlockchain, OnlineBlockchain};
 use crate::database::{BatchDatabase, BatchOperations, DatabaseUtils};
@@ -33,6 +32,8 @@ use crate::psbt::{utils::PSBTUtils, PSBTSatisfier, PSBTSigner};
 use crate::signer::Signer;
 use crate::types::*;
 
+const CACHE_ADDR_BATCH_SIZE: u32 = 100;
+
 pub type OfflineWallet<D> = Wallet<OfflineBlockchain, D>;
 
 pub struct Wallet<B: Blockchain, D: BatchDatabase> {
@@ -93,11 +94,7 @@ where
     }
 
     pub fn get_new_address(&self) -> Result<Address, Error> {
-        let index = self
-            .database
-            .borrow_mut()
-            .increment_last_index(ScriptType::External)?;
-        // TODO: refill the address pool if index is close to the last cached addr
+        let index = self.fetch_and_increment_index(ScriptType::External)?;
 
         self.descriptor
             .derive(index)?
@@ -185,8 +182,10 @@ where
         // script is unknown in the database
         let input_witness_weight = std::cmp::max(
             self.get_descriptor_for(ScriptType::Internal)
+                .0
                 .max_satisfaction_weight(),
             self.get_descriptor_for(ScriptType::External)
+                .0
                 .max_satisfaction_weight(),
         );
 
@@ -283,7 +282,7 @@ where
                 None => continue,
             };
 
-            let desc = self.get_descriptor_for(script_type);
+            let (desc, _) = self.get_descriptor_for(script_type);
             psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?;
             let derived_descriptor = desc.derive(child)?;
 
@@ -537,10 +536,13 @@ where
 
     // Internals
 
-    fn get_descriptor_for(&self, script_type: ScriptType) -> &ExtendedDescriptor {
+    fn get_descriptor_for(&self, script_type: ScriptType) -> (&ExtendedDescriptor, ScriptType) {
         let desc = match script_type {
-            ScriptType::External => &self.descriptor,
-            ScriptType::Internal => &self.change_descriptor.as_ref().unwrap_or(&self.descriptor),
+            ScriptType::Internal if self.change_descriptor.is_some() => (
+                self.change_descriptor.as_ref().unwrap(),
+                ScriptType::Internal,
+            ),
+            _ => (&self.descriptor, ScriptType::External),
         };
 
         desc
@@ -557,22 +559,70 @@ where
     }
 
     fn get_change_address(&self) -> Result<Script, Error> {
-        let (desc, script_type) = if self.change_descriptor.is_none() {
-            (&self.descriptor, ScriptType::External)
-        } else {
-            (
-                self.change_descriptor.as_ref().unwrap(),
-                ScriptType::Internal,
-            )
+        let (desc, script_type) = self.get_descriptor_for(ScriptType::Internal);
+        let index = self.fetch_and_increment_index(script_type)?;
+
+        Ok(desc.derive(index)?.script_pubkey())
+    }
+
+    fn fetch_and_increment_index(&self, script_type: ScriptType) -> Result<u32, Error> {
+        let (descriptor, script_type) = self.get_descriptor_for(script_type);
+        let index = match descriptor.is_fixed() {
+            true => 0,
+            false => self
+                .database
+                .borrow_mut()
+                .increment_last_index(script_type)?,
         };
 
-        // TODO: refill the address pool if index is close to the last cached addr
-        let index = self
+        if self
             .database
-            .borrow_mut()
-            .increment_last_index(script_type)?;
+            .borrow()
+            .get_script_pubkey_from_path(script_type, index)?
+            .is_none()
+        {
+            self.cache_addresses(script_type, index, CACHE_ADDR_BATCH_SIZE)?;
+        }
 
-        Ok(desc.derive(index)?.script_pubkey())
+        Ok(index)
+    }
+
+    fn cache_addresses(
+        &self,
+        script_type: ScriptType,
+        from: u32,
+        mut count: u32,
+    ) -> Result<(), Error> {
+        let (descriptor, script_type) = self.get_descriptor_for(script_type);
+        if descriptor.is_fixed() {
+            if from > 0 {
+                return Ok(());
+            }
+
+            count = 1;
+        }
+
+        let mut address_batch = self.database.borrow().begin_batch();
+
+        let start_time = time::Instant::new();
+        for i in from..(from + count) {
+            address_batch.set_script_pubkey(
+                &descriptor.derive(i)?.script_pubkey(),
+                script_type,
+                i,
+            )?;
+        }
+
+        info!(
+            "Derivation of {} addresses from {} took {} ms",
+            count,
+            from,
+            start_time.elapsed().as_millis()
+        );
+
+        self.database.borrow_mut().commit_batch(address_batch)?;
+
+        Ok(())
     }
 
     fn get_available_utxos(
@@ -621,25 +671,19 @@ where
 
         // try to add hd_keypaths if we've already seen the output
         for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) {
-            debug!("searching hd_keypaths for out: {:?}", out);
-
             if let Some(out) = out {
-                let option_path = self
+                if let Some((script_type, child)) = self
                     .database
                     .borrow()
-                    .get_path_from_script_pubkey(&out.script_pubkey)?;
-
-                debug!("found descriptor path {:?}", option_path);
-
-                let (script_type, child) = match option_path {
-                    None => continue,
-                    Some((script_type, child)) => (script_type, child),
-                };
-
-                // merge hd_keypaths
-                let desc = self.get_descriptor_for(script_type);
-                let mut hd_keypaths = desc.get_hd_keypaths(child)?;
-                psbt_input.hd_keypaths.append(&mut hd_keypaths);
+                    .get_path_from_script_pubkey(&out.script_pubkey)?
+                {
+                    debug!("Found descriptor {:?}/{}", script_type, child);
+
+                    // merge hd_keypaths
+                    let (desc, _) = self.get_descriptor_for(script_type);
+                    let mut hd_keypaths = desc.get_hd_keypaths(child)?;
+                    psbt_input.hd_keypaths.append(&mut hd_keypaths);
+                }
             }
         }
 
@@ -669,61 +713,36 @@ where
     }
 
     #[maybe_async]
-    pub fn sync(
-        &self,
-        max_address: Option<u32>,
-        _batch_query_size: Option<usize>,
-    ) -> Result<(), Error> {
-        debug!("begin sync...");
-        // TODO: consider taking an RwLock as writere here to prevent other "read-only" calls to
-        // break because the db is in an inconsistent state
-
-        let max_address = if self.descriptor.is_fixed() {
-            0
-        } else {
-            max_address.unwrap_or(100)
-        };
-
-        // TODO:
-        // let batch_query_size = batch_query_size.unwrap_or(20);
+    pub fn sync(&self, max_address_param: Option<u32>) -> Result<(), Error> {
+        debug!("Begin sync...");
 
-        let last_addr = self
+        let max_address = match self.descriptor.is_fixed() {
+            true => 0,
+            false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE),
+        };
+        if self
             .database
             .borrow()
-            .get_script_pubkey_from_path(ScriptType::External, max_address)?;
-
-        // cache a few of our addresses
-        if last_addr.is_none() {
-            let mut address_batch = self.database.borrow().begin_batch();
-            let start = time::Instant::new();
+            .get_script_pubkey_from_path(ScriptType::External, max_address)?
+            .is_none()
+        {
+            self.cache_addresses(ScriptType::External, 0, max_address)?;
+        }
 
-            for i in 0..=max_address {
-                let derived = self.descriptor.derive(i).unwrap();
+        if let Some(change_descriptor) = &self.change_descriptor {
+            let max_address = match change_descriptor.is_fixed() {
+                true => 0,
+                false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE),
+            };
 
-                address_batch.set_script_pubkey(
-                    &derived.script_pubkey(),
-                    ScriptType::External,
-                    i,
-                )?;
-            }
-            if self.change_descriptor.is_some() {
-                for i in 0..=max_address {
-                    let derived = self.change_descriptor.as_ref().unwrap().derive(i).unwrap();
-
-                    address_batch.set_script_pubkey(
-                        &derived.script_pubkey(),
-                        ScriptType::Internal,
-                        i,
-                    )?;
-                }
+            if self
+                .database
+                .borrow()
+                .get_script_pubkey_from_path(ScriptType::Internal, max_address)?
+                .is_none()
+            {
+                self.cache_addresses(ScriptType::Internal, 0, max_address)?;
             }
-
-            info!(
-                "derivation of {} addresses, took {} ms",
-                max_address,
-                start.elapsed().as_millis()
-            );
-            self.database.borrow_mut().commit_batch(address_batch)?;
         }
 
         maybe_await!(self.client.sync(
@@ -740,3 +759,104 @@ where
         Ok(tx.txid())
     }
 }
+
+#[cfg(test)]
+mod test {
+    use bitcoin::Network;
+
+    use crate::database::memory::MemoryDatabase;
+    use crate::database::Database;
+    use crate::types::ScriptType;
+
+    use super::*;
+
+    #[test]
+    fn test_cache_addresses_fixed() {
+        let db = MemoryDatabase::new();
+        let wallet: OfflineWallet<_> = Wallet::new_offline(
+            "wpkh(L5EZftvrYaSudiozVRzTqLcHLNDoVn7H5HSfM9BAN6tMJX8oTWz6)",
+            None,
+            Network::Testnet,
+            db,
+        )
+        .unwrap();
+
+        assert_eq!(
+            wallet.get_new_address().unwrap().to_string(),
+            "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835"
+        );
+        assert_eq!(
+            wallet.get_new_address().unwrap().to_string(),
+            "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835"
+        );
+
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::External, 0)
+            .unwrap()
+            .is_some());
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::Internal, 0)
+            .unwrap()
+            .is_none());
+    }
+
+    #[test]
+    fn test_cache_addresses() {
+        let db = MemoryDatabase::new();
+        let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap();
+
+        assert_eq!(
+            wallet.get_new_address().unwrap().to_string(),
+            "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a"
+        );
+        assert_eq!(
+            wallet.get_new_address().unwrap().to_string(),
+            "tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7"
+        );
+
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1)
+            .unwrap()
+            .is_some());
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE)
+            .unwrap()
+            .is_none());
+    }
+
+    #[test]
+    fn test_cache_addresses_refill() {
+        let db = MemoryDatabase::new();
+        let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap();
+
+        assert_eq!(
+            wallet.get_new_address().unwrap().to_string(),
+            "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a"
+        );
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1)
+            .unwrap()
+            .is_some());
+
+        for _ in 0..CACHE_ADDR_BATCH_SIZE {
+            wallet.get_new_address().unwrap();
+        }
+
+        assert!(wallet
+            .database
+            .borrow_mut()
+            .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE * 2 - 1)
+            .unwrap()
+            .is_some());
+    }
+}