]> Untitled Git - bdk/commitdiff
Various `RpcBlockchain` improvements
author志宇 <hello@evanlinjin.me>
Sat, 30 Jul 2022 13:12:18 +0000 (21:12 +0800)
committer志宇 <hello@evanlinjin.me>
Thu, 4 Aug 2022 03:29:38 +0000 (11:29 +0800)
These are as suggested by @danielabrozzoni and @afilini

Also introduced `RpcSyncParams::force_start_time` for users who
prioritise reliability above all else.

Also improved logging.

src/blockchain/rpc.rs

index f711452bf1ed873ae52390022170a072094f35c6..914d037591bc1d075e4a1772c7f96aaee6ad7efa 100644 (file)
@@ -40,14 +40,14 @@ use crate::error::MissingCachedScripts;
 use crate::{BlockTime, Error, FeeRate, KeychainKind, LocalUtxo, TransactionDetails};
 use bitcoin::Script;
 use bitcoincore_rpc::json::{
-    GetTransactionResult, GetTransactionResultDetailCategory, ImportMultiOptions,
-    ImportMultiRequest, ImportMultiRequestScriptPubkey, ImportMultiRescanSince,
-    ListTransactionResult, ScanningDetails,
+    GetTransactionResultDetailCategory, ImportMultiOptions, ImportMultiRequest,
+    ImportMultiRequestScriptPubkey, ImportMultiRescanSince, ListTransactionResult,
+    ListUnspentResultEntry, ScanningDetails,
 };
 use bitcoincore_rpc::jsonrpc::serde_json::{json, Value};
 use bitcoincore_rpc::Auth as RpcAuth;
 use bitcoincore_rpc::{Client, RpcApi};
-use log::debug;
+use log::{debug, info};
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
 use std::path::PathBuf;
@@ -93,6 +93,8 @@ pub struct RpcSyncParams {
     pub start_script_count: usize,
     /// Time in unix seconds in which initial sync will start scanning from (0 to start from genesis).
     pub start_time: u64,
+    /// Forces every sync to use `start_time` as import timestamp.
+    pub force_start_time: bool,
     /// RPC poll rate (in seconds) to get state updates.
     pub poll_rate_sec: u64,
 }
@@ -102,6 +104,7 @@ impl Default for RpcSyncParams {
         Self {
             start_script_count: 100,
             start_time: 0,
+            force_start_time: false,
             poll_rate_sec: 3,
         }
     }
@@ -180,54 +183,15 @@ impl GetBlockHash for RpcBlockchain {
 }
 
 impl WalletSync for RpcBlockchain {
-    fn wallet_setup<D: BatchDatabase>(
-        &self,
-        db: &mut D,
-        progress_update: Box<dyn Progress>,
-    ) -> Result<(), Error> {
-        let db_scripts = db.iter_script_pubkeys(None)?;
-
-        // this is a hack to check whether the scripts are coming from a derivable descriptor
-        // we assume for non-derivable descriptors, the initial script count is always 1
-        let is_derivable = db_scripts.len() > 1;
-
-        // ensure db scripts meet start script count requirements
-        if is_derivable && db_scripts.len() < self.sync_params.start_script_count {
-            return Err(Error::MissingCachedScripts(MissingCachedScripts {
-                last_count: db_scripts.len(),
-                missing_count: self.sync_params.start_script_count - db_scripts.len(),
-            }));
-        }
-
-        // this tells Core wallet where to sync from for imported scripts
-        let start_epoch = db
-            .get_sync_time()?
-            .map_or(self.sync_params.start_time, |st| st.block_time.timestamp);
-
-        // import all scripts from db into Core wallet
-        if self.is_descriptors {
-            import_descriptors(&self.client, start_epoch, db_scripts.iter())?;
-        } else {
-            import_multi(&self.client, start_epoch, db_scripts.iter())?;
-        }
-
-        // await sync (TODO: Maybe make this async)
-        await_wallet_scan(
-            &self.client,
-            self.sync_params.poll_rate_sec,
-            &*progress_update,
-        )?;
-
-        // begin db batch updates
-        let mut db_batch = db.begin_batch();
-
-        // update batch: obtain db state then update state with core txids
-        DbState::from_db(db)?
-            .update_state(&self.client, db)?
-            .update_batch::<D>(&mut db_batch)?;
+    fn wallet_setup<D>(&self, db: &mut D, prog: Box<dyn Progress>) -> Result<(), Error>
+    where
+        D: BatchDatabase,
+    {
+        let batch = DbState::new(db, &self.sync_params, &*prog)?
+            .sync_with_core(&self.client, self.is_descriptors)?
+            .as_db_batch()?;
 
-        // apply batch updates to db
-        db.commit_batch(db_batch)
+        db.commit_batch(batch)
     }
 }
 
@@ -322,7 +286,13 @@ fn list_wallet_dir(client: &Client) -> Result<Vec<String>, Error> {
 }
 
 /// Represents the state of the [`crate::database::Database`].
-struct DbState {
+struct DbState<'a, D> {
+    db: &'a D,
+    params: &'a RpcSyncParams,
+    prog: &'a dyn Progress,
+
+    ext_spks: Vec<Script>,
+    int_spks: Vec<Script>,
     txs: HashMap<Txid, TransactionDetails>,
     utxos: HashSet<LocalUtxo>,
     last_indexes: HashMap<KeychainKind, u32>,
@@ -331,53 +301,94 @@ struct DbState {
     retained_txs: HashSet<Txid>, // txs to retain (everything else should be deleted)
     updated_txs: HashSet<Txid>,  // txs to update
     updated_utxos: HashSet<LocalUtxo>, // utxos to update
-    updated_last_indexes: HashSet<KeychainKind>,
 }
 
-impl DbState {
+impl<'a, D: BatchDatabase> DbState<'a, D> {
     /// Obtain [DbState] from [crate::database::Database].
-    fn from_db<D: BatchDatabase>(db: &D) -> Result<Self, Error> {
+    fn new(db: &'a D, params: &'a RpcSyncParams, prog: &'a dyn Progress) -> Result<Self, Error> {
+        let ext_spks = db.iter_script_pubkeys(Some(KeychainKind::External))?;
+        let int_spks = db.iter_script_pubkeys(Some(KeychainKind::Internal))?;
+
+        // This is a hack to see whether atleast one of the keychains comes from a derivable
+        // descriptor. We assume that non-derivable descriptors always has a script count of 1.
+        let last_count = std::cmp::max(ext_spks.len(), int_spks.len());
+        let has_derivable = last_count > 1;
+
+        // If at least one descriptor is derivable, we need to ensure scriptPubKeys are sufficiently
+        // cached.
+        if has_derivable && last_count < params.start_script_count {
+            let inner_err = MissingCachedScripts {
+                last_count,
+                missing_count: params.start_script_count - last_count,
+            };
+            debug!("requesting more spks with: {:?}", inner_err);
+            return Err(Error::MissingCachedScripts(inner_err));
+        }
+
         let txs = db
             .iter_txs(true)?
             .into_iter()
             .map(|tx| (tx.txid, tx))
             .collect::<HashMap<_, _>>();
+
         let utxos = db.iter_utxos()?.into_iter().collect::<HashSet<_>>();
+
         let last_indexes = [KeychainKind::External, KeychainKind::Internal]
             .iter()
-            .filter_map(|keychain| {
-                db.get_last_index(*keychain)
-                    .map(|v| v.map(|i| (*keychain, i)))
-                    .transpose()
+            .filter_map(|keychain| match db.get_last_index(*keychain) {
+                Ok(li_opt) => li_opt.map(|li| Ok((*keychain, li))),
+                Err(err) => Some(Err(err)),
             })
             .collect::<Result<HashMap<_, _>, Error>>()?;
 
+        info!("initial db state: txs={} utxos={}", txs.len(), utxos.len());
+
+        // "delta" fields
         let retained_txs = HashSet::with_capacity(txs.len());
         let updated_txs = HashSet::with_capacity(txs.len());
         let updated_utxos = HashSet::with_capacity(utxos.len());
-        let updated_last_indexes = HashSet::with_capacity(last_indexes.len());
 
         Ok(Self {
+            db,
+            params,
+            prog,
+            ext_spks,
+            int_spks,
             txs,
             utxos,
             last_indexes,
             retained_txs,
             updated_txs,
             updated_utxos,
-            updated_last_indexes,
         })
     }
 
-    /// Update [DbState] with Core wallet state
-    fn update_state<D>(&mut self, client: &Client, db: &D) -> Result<&mut Self, Error>
-    where
-        D: BatchDatabase,
-    {
-        let tx_iter = CoreTxIter::new(client, 10);
+    /// Sync states of [BatchDatabase] and Core wallet.
+    /// First we import all `scriptPubKey`s from database into core wallet
+    fn sync_with_core(&mut self, client: &Client, is_descriptor: bool) -> Result<&mut Self, Error> {
+        // this tells Core wallet where to sync from for imported scripts
+        let start_epoch = if self.params.force_start_time {
+            self.params.start_time
+        } else {
+            self.db
+                .get_sync_time()?
+                .map_or(self.params.start_time, |st| st.block_time.timestamp)
+        };
 
-        for tx_res in tx_iter {
-            let tx_res = tx_res?;
+        // sync scriptPubKeys from Database to Core wallet
+        let scripts_iter = self.ext_spks.iter().chain(&self.int_spks);
+        if is_descriptor {
+            import_descriptors(client, start_epoch, scripts_iter)?;
+        } else {
+            import_multi(client, start_epoch, scripts_iter)?;
+        }
+
+        // wait for Core wallet to rescan (TODO: maybe make this async)
+        await_wallet_scan(client, self.params.poll_rate_sec, self.prog)?;
 
+        // loop through results of Core RPC method `listtransactions`
+        for tx_res in CoreTxIter::new(client, 100) {
+            let tx_res = tx_res?;
             let mut updated = false;
 
             let db_tx = self.txs.entry(tx_res.info.txid).or_insert_with(|| {
@@ -390,11 +401,11 @@ impl DbState {
 
             // update raw tx (if needed)
             let raw_tx =
-                match &db_tx.transaction {
+                &*match &mut db_tx.transaction {
                     Some(raw_tx) => raw_tx,
-                    None => {
+                    db_tx_opt => {
                         updated = true;
-                        db_tx.transaction.insert(client.get_raw_transaction(
+                        db_tx_opt.insert(client.get_raw_transaction(
                             &tx_res.info.txid,
                             tx_res.info.blockhash.as_ref(),
                         )?)
@@ -415,7 +426,7 @@ impl DbState {
             }
 
             // update received (if needed)
-            let received = Self::_received_from_raw_tx(db, raw_tx)?;
+            let received = Self::received_from_raw_tx(self.db, raw_tx)?;
             if db_tx.received != received {
                 updated = true;
                 db_tx.received = received;
@@ -436,7 +447,7 @@ impl DbState {
                     })?;
 
                 if let Some((keychain, index)) =
-                    db.get_path_from_script_pubkey(&txout.script_pubkey)?
+                    self.db.get_path_from_script_pubkey(&txout.script_pubkey)?
                 {
                     let utxo = LocalUtxo {
                         outpoint: OutPoint::new(tx_res.info.txid, tx_res.detail.vout),
@@ -445,7 +456,7 @@ impl DbState {
                         is_spent: false,
                     };
                     self.updated_utxos.insert(utxo);
-                    self._update_last_index(keychain, index);
+                    self.update_last_index(keychain, index);
                 }
             }
 
@@ -456,16 +467,20 @@ impl DbState {
             }
         }
 
-        // update sent from tx inputs
+        // obtain vector of `TransactionDetails::sent` changes
         let sent_updates = self
             .txs
             .values()
-            .filter_map(|db_tx| {
-                let txid = self.retained_txs.get(&db_tx.txid)?;
-                self._sent_from_raw_tx(db, db_tx.transaction.as_ref()?)
+            // only bother to update txs that are retained
+            .filter(|db_tx| self.retained_txs.contains(&db_tx.txid))
+            // only bother to update txs where the raw tx is accessable
+            .filter_map(|db_tx| (db_tx.transaction.as_ref().map(|tx| (tx, db_tx.sent))))
+            // recalcuate sent value, only update txs in which sent value is changed
+            .filter_map(|(raw_tx, old_sent)| {
+                self.sent_from_raw_tx(raw_tx)
                     .map(|sent| {
-                        if db_tx.sent != sent {
-                            Some((*txid, sent))
+                        if sent != old_sent {
+                            Some((raw_tx.txid(), sent))
                         } else {
                             None
                         }
@@ -475,8 +490,10 @@ impl DbState {
             .collect::<Result<Vec<_>, _>>()?;
 
         // record send updates
-        sent_updates.into_iter().for_each(|(txid, sent)| {
+        sent_updates.iter().for_each(|&(txid, sent)| {
+            // apply sent field changes
             self.txs.entry(txid).and_modify(|db_tx| db_tx.sent = sent);
+            // mark tx as modified
             self.updated_txs.insert(txid);
         });
 
@@ -484,25 +501,21 @@ impl DbState {
         let core_utxos = client
             .list_unspent(Some(0), None, None, Some(true), None)?
             .into_iter()
-            .filter_map(|utxo_res| {
-                db.get_path_from_script_pubkey(&utxo_res.script_pub_key)
-                    .transpose()
-                    .map(|v| {
-                        v.map(|(keychain, index)| {
-                            // update last index if needed
-                            self._update_last_index(keychain, index);
-
-                            LocalUtxo {
-                                outpoint: OutPoint::new(utxo_res.txid, utxo_res.vout),
-                                keychain,
-                                txout: TxOut {
-                                    value: utxo_res.amount.as_sat(),
-                                    script_pubkey: utxo_res.script_pub_key,
-                                },
-                                is_spent: false,
-                            }
-                        })
-                    })
+            .filter_map(|utxo_entry| {
+                let path_result = self
+                    .db
+                    .get_path_from_script_pubkey(&utxo_entry.script_pub_key)
+                    .transpose()?;
+
+                let utxo_result = match path_result {
+                    Ok((keychain, index)) => {
+                        self.update_last_index(keychain, index);
+                        Ok(Self::make_local_utxo(utxo_entry, keychain, false))
+                    }
+                    Err(err) => Err(err),
+                };
+
+                Some(utxo_result)
             })
             .collect::<Result<HashSet<_>, Error>>()?;
 
@@ -521,19 +534,8 @@ impl DbState {
         Ok(self)
     }
 
-    /// We want to filter out conflicting transactions.
-    /// Only accept transactions that are already confirmed, or existing in mempool.
-    fn _filter_tx(client: &Client, res: GetTransactionResult) -> Option<GetTransactionResult> {
-        if res.info.confirmations > 0 || client.get_mempool_entry(&res.info.txid).is_ok() {
-            Some(res)
-        } else {
-            debug!("tx filtered: {}", res.info.txid);
-            None
-        }
-    }
-
     /// Calculates received amount from raw tx.
-    fn _received_from_raw_tx<D: BatchDatabase>(db: &D, raw_tx: &Transaction) -> Result<u64, Error> {
+    fn received_from_raw_tx(db: &D, raw_tx: &Transaction) -> Result<u64, Error> {
         raw_tx.output.iter().try_fold(0_u64, |recv, txo| {
             let v = if db.is_mine(&txo.script_pubkey)? {
                 txo.value
@@ -545,15 +547,16 @@ impl DbState {
     }
 
     /// Calculates sent from raw tx.
-    fn _sent_from_raw_tx<D: BatchDatabase>(
-        &self,
-        db: &D,
-        raw_tx: &Transaction,
-    ) -> Result<u64, Error> {
+    fn sent_from_raw_tx(&self, raw_tx: &Transaction) -> Result<u64, Error> {
+        let get_output = |outpoint: &OutPoint| {
+            let raw_tx = self.txs.get(&outpoint.txid)?.transaction.as_ref()?;
+            raw_tx.output.get(outpoint.vout as usize)
+        };
+
         raw_tx.input.iter().try_fold(0_u64, |sent, txin| {
-            let v = match self._previous_output(&txin.previous_output) {
+            let v = match get_output(&txin.previous_output) {
                 Some(prev_txo) => {
-                    if db.is_mine(&prev_txo.script_pubkey)? {
+                    if self.db.is_mine(&prev_txo.script_pubkey)? {
                         prev_txo.value
                     } else {
                         0
@@ -565,60 +568,74 @@ impl DbState {
         })
     }
 
-    fn _previous_output(&self, outpoint: &OutPoint) -> Option<&TxOut> {
-        let prev_tx = self.txs.get(&outpoint.txid)?.transaction.as_ref()?;
-        prev_tx.output.get(outpoint.vout as usize)
-    }
-
-    fn _update_last_index(&mut self, keychain: KeychainKind, index: u32) {
-        let mut updated = false;
-
+    // updates the db state's last_index for the given keychain (if larger than current last_index)
+    fn update_last_index(&mut self, keychain: KeychainKind, index: u32) {
         self.last_indexes
             .entry(keychain)
             .and_modify(|last| {
                 if *last < index {
-                    updated = true;
                     *last = index;
                 }
             })
-            .or_insert_with(|| {
-                updated = true;
-                index
-            });
-
-        if updated {
-            self.updated_last_indexes.insert(keychain);
+            .or_insert_with(|| index);
+    }
+
+    fn make_local_utxo(
+        entry: ListUnspentResultEntry,
+        keychain: KeychainKind,
+        is_spent: bool,
+    ) -> LocalUtxo {
+        LocalUtxo {
+            outpoint: OutPoint::new(entry.txid, entry.vout),
+            txout: TxOut {
+                value: entry.amount.as_sat(),
+                script_pubkey: entry.script_pub_key,
+            },
+            keychain,
+            is_spent,
         }
     }
 
     /// Prepare db batch operations.
-    fn update_batch<D: BatchDatabase>(&self, batch: &mut D::Batch) -> Result<(), Error> {
-        // delete stale txs from db
-        // stale = not retained
+    fn as_db_batch(&self) -> Result<D::Batch, Error> {
+        let mut batch = self.db.begin_batch();
+        let mut del_txs = 0_u32;
+
+        // delete stale (not retained) txs from db
         self.txs
             .keys()
             .filter(|&txid| !self.retained_txs.contains(txid))
-            .try_for_each(|txid| batch.del_tx(txid, false).map(|_| ()))?;
+            .try_for_each(|txid| -> Result<(), Error> {
+                batch.del_tx(txid, false)?;
+                del_txs += 1;
+                Ok(())
+            })?;
 
         // update txs
         self.updated_txs
             .iter()
-            .filter_map(|txid| self.txs.get(txid))
-            .try_for_each(|txd| batch.set_tx(txd))?;
+            .inspect(|&txid| debug!("updating tx: {}", txid))
+            .try_for_each(|txid| batch.set_tx(self.txs.get(txid).unwrap()))?;
 
         // update utxos
         self.updated_utxos
             .iter()
-            .inspect(|&utxo| println!("updating: {:?}", utxo.txout))
+            .inspect(|&utxo| debug!("updating utxo: {}", utxo.outpoint))
             .try_for_each(|utxo| batch.set_utxo(utxo))?;
 
         // update last indexes
-        self.updated_last_indexes
+        self.last_indexes
             .iter()
-            .map(|keychain| self.last_indexes.get_key_value(keychain).unwrap())
             .try_for_each(|(&keychain, &index)| batch.set_last_index(keychain, index))?;
 
-        Ok(())
+        info!(
+            "db batch updates: del_txs={}, update_txs={}, update_utxos={}",
+            del_txs,
+            self.updated_txs.len(),
+            self.updated_utxos.len()
+        );
+
+        Ok(batch)
     }
 }
 
@@ -678,6 +695,7 @@ where
     Ok(())
 }
 
+/// Iterates through results of multiple `listtransactions` calls.
 struct CoreTxIter<'a> {
     client: &'a Client,
     page_size: usize,
@@ -688,7 +706,11 @@ struct CoreTxIter<'a> {
 }
 
 impl<'a> CoreTxIter<'a> {
-    fn new(client: &'a Client, page_size: usize) -> Self {
+    fn new(client: &'a Client, mut page_size: usize) -> Self {
+        if page_size > 1000 {
+            page_size = 1000;
+        }
+
         Self {
             client,
             page_size,
@@ -700,7 +722,7 @@ impl<'a> CoreTxIter<'a> {
 
     /// We want to filter out conflicting transactions.
     /// Only accept transactions that are already confirmed, or existing in mempool.
-    fn tx_ok(&self, item: &ListTransactionResult) -> bool {
+    fn keep_tx(&self, item: &ListTransactionResult) -> bool {
         item.info.confirmations > 0 || self.client.get_mempool_entry(&item.info.txid).is_ok()
     }
 }
@@ -715,7 +737,7 @@ impl<'a> Iterator for CoreTxIter<'a> {
             }
 
             if let Some(item) = self.stack.pop() {
-                if self.tx_ok(&item) {
+                if self.keep_tx(&item) {
                     return Some(Ok(item));
                 }
             }
@@ -750,32 +772,26 @@ impl<'a> Iterator for CoreTxIter<'a> {
     }
 }
 
-fn get_scanning_details(client: &Client) -> Result<ScanningDetails, Error> {
+fn await_wallet_scan(client: &Client, rate_sec: u64, progress: &dyn Progress) -> Result<(), Error> {
     #[derive(Deserialize)]
     struct CallResult {
         scanning: ScanningDetails,
     }
-    let result: CallResult = client.call("getwalletinfo", &[])?;
-    Ok(result.scanning)
-}
 
-fn await_wallet_scan(
-    client: &Client,
-    poll_rate_sec: u64,
-    progress_update: &dyn Progress,
-) -> Result<(), Error> {
-    let dur = Duration::from_secs(poll_rate_sec);
+    let dur = Duration::from_secs(rate_sec);
     loop {
-        match get_scanning_details(client)? {
-            ScanningDetails::Scanning { duration, progress } => {
-                println!("scanning: duration={}, progress={}", duration, progress);
-                progress_update
-                    .update(progress, Some(format!("elapsed for {} seconds", duration)))?;
+        match client.call::<CallResult>("getwalletinfo", &[])?.scanning {
+            ScanningDetails::Scanning {
+                duration,
+                progress: pc,
+            } => {
+                debug!("scanning: duration={}, progress={}", duration, pc);
+                progress.update(pc, Some(format!("elapsed for {} seconds", duration)))?;
                 thread::sleep(dur);
             }
             ScanningDetails::NotScanning(_) => {
-                progress_update.update(1.0, None)?;
-                println!("scanning: done!");
+                progress.update(1.0, None)?;
+                info!("scanning: done!");
                 return Ok(());
             }
         };