Untitled Git - bdk/commitdiff
feat(electrum): optimize merkle proof validation with batching
author Wei Chen <wzc110@gmail.com>
Sat, 22 Mar 2025 12:16:12 +0000 (17:46 +0530)
committer Wei Chen <wzc110@gmail.com>
Wed, 18 Jun 2025 21:35:45 +0000 (21:35 +0000)
Co-authored-by: keerthi <keerthi.sree2105@gmail.com>
crates/electrum/src/bdk_electrum_client.rs
crates/electrum/tests/test_electrum.rs

index e0eac5083274087e919a49b0d8468a72d812e1f4..f4ab32b206554944596b834844105c1e8256d12b 100644 (file)
@@ -12,6 +12,9 @@ use std::sync::{Arc, Mutex};
 /// We include a chain suffix of a certain length for the purpose of robustness.
 const CHAIN_SUFFIX_LENGTH: u32 = 8;
 
+/// Maximum batch size for proof validation requests
+const MAX_BATCH_SIZE: usize = 100;
+
 /// Wrapper around an [`electrum_client::ElectrumApi`] which includes an internal in-memory
 /// transaction cache to avoid re-fetching already downloaded transactions.
 #[derive(Debug)]
@@ -22,6 +25,8 @@ pub struct BdkElectrumClient<E> {
     tx_cache: Mutex<HashMap<Txid, Arc<Transaction>>>,
     /// The header cache
     block_header_cache: Mutex<HashMap<u32, Header>>,
+    /// Cache of transaction anchors
+    anchor_cache: Mutex<HashMap<(Txid, BlockHash), ConfirmationBlockTime>>,
 }
 
 impl<E: ElectrumApi> BdkElectrumClient<E> {
@@ -31,6 +36,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
             inner: client,
             tx_cache: Default::default(),
             block_header_cache: Default::default(),
+            anchor_cache: Default::default(),
         }
     }
 
@@ -135,13 +141,19 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
 
         let mut tx_update = TxUpdate::<ConfirmationBlockTime>::default();
         let mut last_active_indices = BTreeMap::<K, u32>::default();
+        let mut pending_anchors = Vec::new();
         for keychain in request.keychains() {
             let spks = request
                 .iter_spks(keychain.clone())
                 .map(|(spk_i, spk)| (spk_i, SpkWithExpectedTxids::from(spk)));
-            if let Some(last_active_index) =
-                self.populate_with_spks(start_time, &mut tx_update, spks, stop_gap, batch_size)?
-            {
+            if let Some(last_active_index) = self.populate_with_spks(
+                start_time,
+                &mut tx_update,
+                spks,
+                stop_gap,
+                batch_size,
+                &mut pending_anchors,
+            )? {
                 last_active_indices.insert(keychain, last_active_index);
             }
         }
@@ -151,6 +163,13 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
             self.fetch_prev_txout(&mut tx_update)?;
         }
 
+        if !pending_anchors.is_empty() {
+            let anchors = self.batch_fetch_anchors(&pending_anchors)?;
+            for (txid, anchor) in anchors {
+                tx_update.anchors.insert((anchor, txid));
+            }
+        }
+
         let chain_update = match tip_and_latest_blocks {
             Some((chain_tip, latest_blocks)) => Some(chain_update(
                 chain_tip,
@@ -204,6 +223,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
         };
 
         let mut tx_update = TxUpdate::<ConfirmationBlockTime>::default();
+        let mut pending_anchors = Vec::new();
         self.populate_with_spks(
             start_time,
             &mut tx_update,
@@ -213,15 +233,33 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                 .map(|(i, spk)| (i as u32, spk)),
             usize::MAX,
             batch_size,
+            &mut pending_anchors,
+        )?;
+        self.populate_with_txids(
+            start_time,
+            &mut tx_update,
+            request.iter_txids(),
+            &mut pending_anchors,
+        )?;
+        self.populate_with_outpoints(
+            start_time,
+            &mut tx_update,
+            request.iter_outpoints(),
+            &mut pending_anchors,
         )?;
-        self.populate_with_txids(start_time, &mut tx_update, request.iter_txids())?;
-        self.populate_with_outpoints(start_time, &mut tx_update, request.iter_outpoints())?;
 
         // Fetch previous `TxOut`s for fee calculation if flag is enabled.
         if fetch_prev_txouts {
             self.fetch_prev_txout(&mut tx_update)?;
         }
 
+        if !pending_anchors.is_empty() {
+            let anchors = self.batch_fetch_anchors(&pending_anchors)?;
+            for (txid, anchor) in anchors {
+                tx_update.anchors.insert((anchor, txid));
+            }
+        }
+
         let chain_update = match tip_and_latest_blocks {
             Some((chain_tip, latest_blocks)) => Some(chain_update(
                 chain_tip,
@@ -249,16 +287,17 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
         mut spks_with_expected_txids: impl Iterator<Item = (u32, SpkWithExpectedTxids)>,
         stop_gap: usize,
         batch_size: usize,
+        pending_anchors: &mut Vec<(Txid, usize)>,
     ) -> Result<Option<u32>, Error> {
-        let mut unused_spk_count = 0_usize;
-        let mut last_active_index = Option::<u32>::None;
+        let mut unused_spk_count = 0;
+        let mut last_active_index = None;
 
         loop {
             let spks = (0..batch_size)
                 .map_while(|_| spks_with_expected_txids.next())
                 .collect::<Vec<_>>();
             if spks.is_empty() {
-                return Ok(last_active_index);
+                break;
             }
 
             let spk_histories = self
@@ -267,9 +306,9 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
 
             for ((spk_index, spk), spk_history) in spks.into_iter().zip(spk_histories) {
                 if spk_history.is_empty() {
-                    unused_spk_count = unused_spk_count.saturating_add(1);
+                    unused_spk_count += 1;
                     if unused_spk_count >= stop_gap {
-                        return Ok(last_active_index);
+                        break;
                     }
                 } else {
                     last_active_index = Some(spk_index);
@@ -292,7 +331,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                     match tx_res.height.try_into() {
                         // Returned heights 0 & -1 are reserved for unconfirmed txs.
                         Ok(height) if height > 0 => {
-                            self.validate_merkle_for_anchor(tx_update, tx_res.tx_hash, height)?;
+                            pending_anchors.push((tx_res.tx_hash, height));
                         }
                         _ => {
                             tx_update.seen_ats.insert((tx_res.tx_hash, start_time));
@@ -301,6 +340,8 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                 }
             }
         }
+
+        Ok(last_active_index)
     }
 
     /// Populate the `tx_update` with associated transactions/anchors of `outpoints`.
@@ -312,6 +353,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
         start_time: u64,
         tx_update: &mut TxUpdate<ConfirmationBlockTime>,
         outpoints: impl IntoIterator<Item = OutPoint>,
+        pending_anchors: &mut Vec<(Txid, usize)>,
     ) -> Result<(), Error> {
         for outpoint in outpoints {
             let op_txid = outpoint.txid;
@@ -337,7 +379,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                     match res.height.try_into() {
                         // Returned heights 0 & -1 are reserved for unconfirmed txs.
                         Ok(height) if height > 0 => {
-                            self.validate_merkle_for_anchor(tx_update, res.tx_hash, height)?;
+                            pending_anchors.push((res.tx_hash, height));
                         }
                         _ => {
                             tx_update.seen_ats.insert((res.tx_hash, start_time));
@@ -359,7 +401,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                     match res.height.try_into() {
                         // Returned heights 0 & -1 are reserved for unconfirmed txs.
                         Ok(height) if height > 0 => {
-                            self.validate_merkle_for_anchor(tx_update, res.tx_hash, height)?;
+                            pending_anchors.push((res.tx_hash, height));
                         }
                         _ => {
                             tx_update.seen_ats.insert((res.tx_hash, start_time));
@@ -368,6 +410,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                 }
             }
         }
+
         Ok(())
     }
 
@@ -377,6 +420,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
         start_time: u64,
         tx_update: &mut TxUpdate<ConfirmationBlockTime>,
         txids: impl IntoIterator<Item = Txid>,
+        pending_anchors: &mut Vec<(Txid, usize)>,
     ) -> Result<(), Error> {
         for txid in txids {
             let tx = match self.fetch_tx(txid) {
@@ -402,7 +446,7 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
                 match r.height.try_into() {
                     // Returned heights 0 & -1 are reserved for unconfirmed txs.
                     Ok(height) if height > 0 => {
-                        self.validate_merkle_for_anchor(tx_update, txid, height)?;
+                        pending_anchors.push((txid, height));
                     }
                     _ => {
                         tx_update.seen_ats.insert((r.tx_hash, start_time));
@@ -412,52 +456,99 @@ impl<E: ElectrumApi> BdkElectrumClient<E> {
 
             tx_update.txs.push(tx);
         }
+
         Ok(())
     }
 
-    // Helper function which checks if a transaction is confirmed by validating the merkle proof.
-    // An anchor is inserted if the transaction is validated to be in a confirmed block.
-    fn validate_merkle_for_anchor(
+    /// Batch validate Merkle proofs, cache each confirmation anchor, and return them.
+    fn batch_fetch_anchors(
         &self,
-        tx_update: &mut TxUpdate<ConfirmationBlockTime>,
-        txid: Txid,
-        confirmation_height: usize,
-    ) -> Result<(), Error> {
-        if let Ok(merkle_res) = self
-            .inner
-            .transaction_get_merkle(&txid, confirmation_height)
+        txs_with_heights: &[(Txid, usize)],
+    ) -> Result<Vec<(Txid, ConfirmationBlockTime)>, Error> {
+        let mut results = Vec::with_capacity(txs_with_heights.len());
+        let mut to_fetch = Vec::new();
+
+        // Build a map for height to block hash conversions. This is for obtaining block hash data
+        // with minimum `fetch_header` calls.
+        let mut height_to_hash: HashMap<u32, BlockHash> = HashMap::new();
+        for &(_, height) in txs_with_heights {
+            let h = height as u32;
+            if !height_to_hash.contains_key(&h) {
+                // Try to obtain hash from the header cache, or fetch the header if absent.
+                let hash = self.fetch_header(h)?.block_hash();
+                height_to_hash.insert(h, hash);
+            }
+        }
+
+        // Check cache.
         {
-            let mut header = self.fetch_header(merkle_res.block_height as u32)?;
-            let mut is_confirmed_tx = electrum_client::utils::validate_merkle_proof(
-                &txid,
-                &header.merkle_root,
-                &merkle_res,
-            );
-
-            // Merkle validation will fail if the header in `block_header_cache` is outdated, so we
-            // want to check if there is a new header and validate against the new one.
-            if !is_confirmed_tx {
-                header = self.update_header(merkle_res.block_height as u32)?;
-                is_confirmed_tx = electrum_client::utils::validate_merkle_proof(
+            let anchor_cache = self.anchor_cache.lock().unwrap();
+            for &(txid, height) in txs_with_heights {
+                let h = height as u32;
+                let hash = height_to_hash[&h];
+                if let Some(anchor) = anchor_cache.get(&(txid, hash)) {
+                    results.push((txid, *anchor));
+                } else {
+                    to_fetch.push((txid, height, hash));
+                }
+            }
+        }
+
+        // Fetch missing proofs in batches
+        for chunk in to_fetch.chunks(MAX_BATCH_SIZE) {
+            for &(txid, height, hash) in chunk {
+                // Fetch the raw proof.
+                let proof = self.inner.transaction_get_merkle(&txid, height)?;
+
+                // Validate against header, retrying once on stale header.
+                let mut header = self.fetch_header(height as u32)?;
+                let mut valid = electrum_client::utils::validate_merkle_proof(
                     &txid,
                     &header.merkle_root,
-                    &merkle_res,
+                    &proof,
                 );
-            }
+                if !valid {
+                    header = self.update_header(height as u32)?;
+                    valid = electrum_client::utils::validate_merkle_proof(
+                        &txid,
+                        &header.merkle_root,
+                        &proof,
+                    );
+                }
 
-            if is_confirmed_tx {
-                tx_update.anchors.insert((
-                    ConfirmationBlockTime {
+                // Build and cache the anchor if merkle proof is valid.
+                if valid {
+                    let anchor = ConfirmationBlockTime {
                         confirmation_time: header.time as u64,
                         block_id: BlockId {
-                            height: merkle_res.block_height as u32,
+                            height: height as u32,
                             hash: header.block_hash(),
                         },
-                    },
-                    txid,
-                ));
+                    };
+                    self.anchor_cache
+                        .lock()
+                        .unwrap()
+                        .insert((txid, hash), anchor);
+                    results.push((txid, anchor));
+                }
             }
         }
+
+        Ok(results)
+    }
+
+    /// Validate a single transaction’s Merkle proof, cache its confirmation anchor, and update.
+    #[allow(dead_code)]
+    fn validate_anchor_for_update(
+        &self,
+        tx_update: &mut TxUpdate<ConfirmationBlockTime>,
+        txid: Txid,
+        confirmation_height: usize,
+    ) -> Result<(), Error> {
+        let anchors = self.batch_fetch_anchors(&[(txid, confirmation_height)])?;
+        for (txid, anchor) in anchors {
+            tx_update.anchors.insert((anchor, txid));
+        }
         Ok(())
     }
 
index 5302e62f271a51fd97cea0dc626a3c08e1c0931e..7b6a63cd8559da9e82e66db5b7f06471c1b555ff 100644 (file)
@@ -20,6 +20,7 @@ use core::time::Duration;
 use electrum_client::ElectrumApi;
 use std::collections::{BTreeSet, HashMap, HashSet};
 use std::str::FromStr;
+use std::time::Instant;
 
 // Batch size for `sync_with_electrum`.
 const BATCH_SIZE: usize = 5;
@@ -881,3 +882,51 @@ fn test_check_fee_calculation() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[test]
+pub fn test_sync_performance() -> anyhow::Result<()> {
+    const EXPECTED_MAX_SYNC_TIME: Duration = Duration::from_secs(5);
+    const NUM_ADDRESSES: usize = 1000;
+
+    let env = TestEnv::new()?;
+    let electrum_client = electrum_client::Client::new(env.electrsd.electrum_url.as_str())?;
+    let client = BdkElectrumClient::new(electrum_client);
+
+    // Generate test addresses.
+    let mut spks = Vec::with_capacity(NUM_ADDRESSES);
+    for _ in 0..NUM_ADDRESSES {
+        spks.push(get_test_spk());
+    }
+
+    // Mine some blocks and send transactions.
+    env.mine_blocks(101, None)?;
+    for spk in spks.iter().take(10) {
+        let addr = Address::from_script(spk, Network::Regtest)?;
+        env.send(&addr, Amount::from_sat(10_000))?;
+    }
+    env.mine_blocks(1, None)?;
+
+    // Setup receiver.
+    let (mut recv_chain, _) = LocalChain::from_genesis_hash(env.bitcoind.client.get_block_hash(0)?);
+    let mut recv_graph = IndexedTxGraph::<ConfirmationBlockTime, _>::new({
+        let mut recv_index = SpkTxOutIndex::default();
+        for spk in spks.iter() {
+            recv_index.insert_spk((), spk.clone());
+        }
+        recv_index
+    });
+
+    // Measure sync time.
+    let start = Instant::now();
+    let _ = sync_with_electrum(&client, spks.clone(), &mut recv_chain, &mut recv_graph)?;
+    let sync_duration = start.elapsed();
+
+    assert!(
+        sync_duration <= EXPECTED_MAX_SYNC_TIME,
+        "Sync took {:?}, which is longer than expected {:?}",
+        sync_duration,
+        EXPECTED_MAX_SYNC_TIME
+    );
+
+    Ok(())
+}