]> Untitled Git - bdk/commitdiff
feat(chain): `SyncRequest` now uses `ExactSizeIterator`s
author志宇 <hello@evanlinjin.me>
Sat, 27 Apr 2024 12:40:08 +0000 (20:40 +0800)
committer志宇 <hello@evanlinjin.me>
Sat, 27 Apr 2024 12:40:08 +0000 (20:40 +0800)
This allows the caller to track sync progress.

crates/chain/src/spk_client.rs
example-crates/example_esplora/src/main.rs

index 7873ba227072d7dfd49f0b0d4bfe98a39683cd09..eefa211c6ae6892ee7a3c95d68134be2958c452d 100644 (file)
@@ -1,6 +1,6 @@
 //! Helper types for spk-based blockchain clients.
 
-use core::{fmt::Debug, ops::RangeBounds};
+use core::{fmt::Debug, marker::PhantomData, ops::RangeBounds};
 
 use alloc::{boxed::Box, collections::BTreeMap, vec::Vec};
 use bitcoin::{OutPoint, Script, ScriptBuf, Txid};
@@ -18,11 +18,11 @@ pub struct SyncRequest {
     /// [`LocalChain::tip`]: crate::local_chain::LocalChain::tip
     pub chain_tip: CheckPoint,
     /// Transactions that spend from or to these indexed script pubkeys.
-    pub spks: Box<dyn Iterator<Item = ScriptBuf> + Send>,
+    pub spks: Box<dyn ExactSizeIterator<Item = ScriptBuf> + Send>,
     /// Transactions with these txids.
-    pub txids: Box<dyn Iterator<Item = Txid> + Send>,
+    pub txids: Box<dyn ExactSizeIterator<Item = Txid> + Send>,
     /// Transactions with these outpoints or spent from these outpoints.
-    pub outpoints: Box<dyn Iterator<Item = OutPoint> + Send>,
+    pub outpoints: Box<dyn ExactSizeIterator<Item = OutPoint> + Send>,
 }
 
 impl SyncRequest {
@@ -42,7 +42,7 @@ impl SyncRequest {
     #[must_use]
     pub fn set_spks(
         mut self,
-        spks: impl IntoIterator<IntoIter = impl Iterator<Item = ScriptBuf> + Send + 'static>,
+        spks: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = ScriptBuf> + Send + 'static>,
     ) -> Self {
         self.spks = Box::new(spks.into_iter());
         self
@@ -54,7 +54,7 @@ impl SyncRequest {
     #[must_use]
     pub fn set_txids(
         mut self,
-        txids: impl IntoIterator<IntoIter = impl Iterator<Item = Txid> + Send + 'static>,
+        txids: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Txid> + Send + 'static>,
     ) -> Self {
         self.txids = Box::new(txids.into_iter());
         self
@@ -66,7 +66,9 @@ impl SyncRequest {
     #[must_use]
     pub fn set_outpoints(
         mut self,
-        outpoints: impl IntoIterator<IntoIter = impl Iterator<Item = OutPoint> + Send + 'static>,
+        outpoints: impl IntoIterator<
+            IntoIter = impl ExactSizeIterator<Item = OutPoint> + Send + 'static,
+        >,
     ) -> Self {
         self.outpoints = Box::new(outpoints.into_iter());
         self
@@ -79,11 +81,11 @@ impl SyncRequest {
     pub fn chain_spks(
         mut self,
         spks: impl IntoIterator<
-            IntoIter = impl Iterator<Item = ScriptBuf> + Send + 'static,
+            IntoIter = impl ExactSizeIterator<Item = ScriptBuf> + Send + 'static,
             Item = ScriptBuf,
         >,
     ) -> Self {
-        self.spks = Box::new(self.spks.chain(spks));
+        self.spks = Box::new(ExactSizeChain::new(self.spks, spks.into_iter()));
         self
     }
 
@@ -93,9 +95,12 @@ impl SyncRequest {
     #[must_use]
     pub fn chain_txids(
         mut self,
-        txids: impl IntoIterator<IntoIter = impl Iterator<Item = Txid> + Send + 'static, Item = Txid>,
+        txids: impl IntoIterator<
+            IntoIter = impl ExactSizeIterator<Item = Txid> + Send + 'static,
+            Item = Txid,
+        >,
     ) -> Self {
-        self.txids = Box::new(self.txids.chain(txids));
+        self.txids = Box::new(ExactSizeChain::new(self.txids, txids.into_iter()));
         self
     }
 
@@ -106,39 +111,42 @@ impl SyncRequest {
     pub fn chain_outpoints(
         mut self,
         outpoints: impl IntoIterator<
-            IntoIter = impl Iterator<Item = OutPoint> + Send + 'static,
+            IntoIter = impl ExactSizeIterator<Item = OutPoint> + Send + 'static,
             Item = OutPoint,
         >,
     ) -> Self {
-        self.outpoints = Box::new(self.outpoints.chain(outpoints));
+        self.outpoints = Box::new(ExactSizeChain::new(self.outpoints, outpoints.into_iter()));
         self
     }
 
-    /// Add a closure that will be called for each [`Script`] synced in this request.
+    /// Add a closure that will be called for [`Script`]s previously added to this request.
     ///
     /// This consumes the [`SyncRequest`] and returns the updated one.
     #[must_use]
-    pub fn inspect_spks(mut self, inspect: impl Fn(&Script) + Send + Sync + 'static) -> Self {
+    pub fn inspect_spks(
+        mut self,
+        mut inspect: impl FnMut(&Script) + Send + Sync + 'static,
+    ) -> Self {
         self.spks = Box::new(self.spks.inspect(move |spk| inspect(spk)));
         self
     }
 
-    /// Add a closure that will be called for each [`Txid`] synced in this request.
+    /// Add a closure that will be called for [`Txid`]s previously added to this request.
     ///
     /// This consumes the [`SyncRequest`] and returns the updated one.
     #[must_use]
-    pub fn inspect_txids(mut self, inspect: impl Fn(&Txid) + Send + Sync + 'static) -> Self {
+    pub fn inspect_txids(mut self, mut inspect: impl FnMut(&Txid) + Send + Sync + 'static) -> Self {
         self.txids = Box::new(self.txids.inspect(move |txid| inspect(txid)));
         self
     }
 
-    /// Add a closure that will be called for each [`OutPoint`] synced in this request.
+    /// Add a closure that will be called for [`OutPoint`]s previously added to this request.
     ///
     /// This consumes the [`SyncRequest`] and returns the updated one.
     #[must_use]
     pub fn inspect_outpoints(
         mut self,
-        inspect: impl Fn(&OutPoint) + Send + Sync + 'static,
+        mut inspect: impl FnMut(&OutPoint) + Send + Sync + 'static,
     ) -> Self {
         self.outpoints = Box::new(self.outpoints.inspect(move |op| inspect(op)));
         self
@@ -313,3 +321,64 @@ pub struct FullScanResult<K> {
     /// Last active indices for the corresponding keychains (`K`).
     pub last_active_indices: BTreeMap<K, u32>,
 }
+
+/// A version of [`core::iter::Chain`] which can combine two [`ExactSizeIterator`]s to form a new
+/// [`ExactSizeIterator`].
+///
+/// The danger of this is explained in [the `ExactSizeIterator` docs]
+/// (https://doc.rust-lang.org/core/iter/trait.ExactSizeIterator.html#when-shouldnt-an-adapter-be-exactsizeiterator).
+/// This does not apply here since it would be impossible to scan an item count that overflows
+/// `usize` anyway.
+struct ExactSizeChain<A, B, I> {
+    a: Option<A>,
+    b: Option<B>,
+    i: PhantomData<I>,
+}
+
+impl<A, B, I> ExactSizeChain<A, B, I> {
+    fn new(a: A, b: B) -> Self {
+        ExactSizeChain {
+            a: Some(a),
+            b: Some(b),
+            i: PhantomData,
+        }
+    }
+}
+
+impl<A, B, I> Iterator for ExactSizeChain<A, B, I>
+where
+    A: Iterator<Item = I>,
+    B: Iterator<Item = I>,
+{
+    type Item = I;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(a) = &mut self.a {
+            let item = a.next();
+            if item.is_some() {
+                return item;
+            }
+            self.a = None;
+        }
+        if let Some(b) = &mut self.b {
+            let item = b.next();
+            if item.is_some() {
+                return item;
+            }
+            self.b = None;
+        }
+        None
+    }
+}
+
+impl<A, B, I> ExactSizeIterator for ExactSizeChain<A, B, I>
+where
+    A: ExactSizeIterator<Item = I>,
+    B: ExactSizeIterator<Item = I>,
+{
+    fn len(&self) -> usize {
+        let a_len = self.a.as_ref().map(|a| a.len()).unwrap_or(0);
+        let b_len = self.b.as_ref().map(|a| a.len()).unwrap_or(0);
+        a_len + b_len
+    }
+}
index 46eb18b810334eb36650a4f246145401344db52a..e785bcc3bf0a6c9548d1acb6d8aeea61791a03fc 100644 (file)
@@ -248,7 +248,7 @@ fn main() -> anyhow::Result<()> {
                         .map(|(k, i, spk)| (k.to_owned(), i, spk.to_owned()))
                         .collect::<Vec<_>>();
                     request = request.chain_spks(all_spks.into_iter().map(|(k, i, spk)| {
-                        eprintln!("scanning {}:{}", k, i);
+                        eprint!("scanning {}:{}", k, i);
                         // Flush early to ensure we print at every iteration.
                         let _ = io::stderr().flush();
                         spk
@@ -262,7 +262,7 @@ fn main() -> anyhow::Result<()> {
                         .collect::<Vec<_>>();
                     request =
                         request.chain_spks(unused_spks.into_iter().map(move |(k, i, spk)| {
-                            eprintln!(
+                            eprint!(
                                 "Checking if address {} {}:{} has been used",
                                 Address::from_script(&spk, args.network).unwrap(),
                                 k,
@@ -287,7 +287,7 @@ fn main() -> anyhow::Result<()> {
                         utxos
                             .into_iter()
                             .inspect(|utxo| {
-                                eprintln!(
+                                eprint!(
                                     "Checking if outpoint {} (value: {}) has been spent",
                                     utxo.outpoint, utxo.txout.value
                                 );
@@ -308,13 +308,38 @@ fn main() -> anyhow::Result<()> {
                         .map(|canonical_tx| canonical_tx.tx_node.txid)
                         .collect::<Vec<Txid>>();
                     request = request.chain_txids(unconfirmed_txids.into_iter().inspect(|txid| {
-                        eprintln!("Checking if {} is confirmed yet", txid);
+                        eprint!("Checking if {} is confirmed yet", txid);
                         // Flush early to ensure we print at every iteration.
                         let _ = io::stderr().flush();
                     }));
                 }
             }
 
+            let total_spks = request.spks.len();
+            let total_txids = request.txids.len();
+            let total_ops = request.outpoints.len();
+            request = request
+                .inspect_spks({
+                    let mut visited = 0;
+                    move |_| {
+                        visited += 1;
+                        eprintln!(" [ {:>6.2}% ]", (visited * 100) as f32 / total_spks as f32)
+                    }
+                })
+                .inspect_txids({
+                    let mut visited = 0;
+                    move |_| {
+                        visited += 1;
+                        eprintln!(" [ {:>6.2}% ]", (visited * 100) as f32 / total_txids as f32)
+                    }
+                })
+                .inspect_outpoints({
+                    let mut visited = 0;
+                    move |_| {
+                        visited += 1;
+                        eprintln!(" [ {:>6.2}% ]", (visited * 100) as f32 / total_ops as f32)
+                    }
+                });
             let mut update = client.sync(request, scan_options.parallel_requests)?;
 
             // Update last seen unconfirmed