]> Untitled Git - bdk/commitdiff
Improve `txout` listing and balance APIs for redesigned structures
author志宇 <hello@evanlinjin.me>
Wed, 10 May 2023 06:14:29 +0000 (14:14 +0800)
committer志宇 <hello@evanlinjin.me>
Wed, 10 May 2023 15:57:24 +0000 (23:57 +0800)
Instead of relying on a `OwnedIndexer` trait to filter for relevant
txouts, we move the listing and balance methods from `IndexedTxGraph` to
`TxGraph` and add an additional input (list of relevant outpoints) to
these methods.

This provides a simpler implementation and a more flexible API.

crates/chain/src/indexed_tx_graph.rs
crates/chain/src/keychain/txout_index.rs
crates/chain/src/spk_txout_index.rs
crates/chain/src/tx_graph.rs
crates/chain/tests/test_indexed_tx_graph.rs

index 7ab0ffa8a488d69b28b72382f0a19cbcf7126452..f69b227a21f2929cc28af0a0e2df2971bf3c865b 100644 (file)
@@ -1,12 +1,9 @@
-use core::convert::Infallible;
-
 use alloc::vec::Vec;
-use bitcoin::{OutPoint, Script, Transaction, TxOut};
+use bitcoin::{OutPoint, Transaction, TxOut};
 
 use crate::{
-    keychain::Balance,
     tx_graph::{Additions, TxGraph},
-    Anchor, Append, BlockId, ChainOracle, FullTxOut, ObservedAs,
+    Anchor, Append,
 };
 
 /// A struct that combines [`TxGraph`] and an [`Indexer`] implementation.
@@ -29,6 +26,14 @@ impl<A, I: Default> Default for IndexedTxGraph<A, I> {
 }
 
 impl<A, I> IndexedTxGraph<A, I> {
+    /// Construct a new [`IndexedTxGraph`] with a given `index`.
+    pub fn new(index: I) -> Self {
+        Self {
+            index,
+            graph: TxGraph::default(),
+        }
+    }
+
     /// Get a reference of the internal transaction graph.
     pub fn graph(&self) -> &TxGraph<A> {
         &self.graph
@@ -157,115 +162,6 @@ where
     }
 }
 
-impl<A: Anchor, I: OwnedIndexer> IndexedTxGraph<A, I> {
-    pub fn try_list_owned_txouts<'a, C: ChainOracle + 'a>(
-        &'a self,
-        chain: &'a C,
-        chain_tip: BlockId,
-    ) -> impl Iterator<Item = Result<FullTxOut<ObservedAs<A>>, C::Error>> + 'a {
-        self.graph()
-            .try_list_chain_txouts(chain, chain_tip)
-            .filter(|r| {
-                if let Ok(full_txout) = r {
-                    if !self.index.is_spk_owned(&full_txout.txout.script_pubkey) {
-                        return false;
-                    }
-                }
-                true
-            })
-    }
-
-    pub fn list_owned_txouts<'a, C: ChainOracle<Error = Infallible> + 'a>(
-        &'a self,
-        chain: &'a C,
-        chain_tip: BlockId,
-    ) -> impl Iterator<Item = FullTxOut<ObservedAs<A>>> + 'a {
-        self.try_list_owned_txouts(chain, chain_tip)
-            .map(|r| r.expect("oracle is infallible"))
-    }
-
-    pub fn try_list_owned_unspents<'a, C: ChainOracle + 'a>(
-        &'a self,
-        chain: &'a C,
-        chain_tip: BlockId,
-    ) -> impl Iterator<Item = Result<FullTxOut<ObservedAs<A>>, C::Error>> + 'a {
-        self.graph()
-            .try_list_chain_unspents(chain, chain_tip)
-            .filter(|r| {
-                if let Ok(full_txout) = r {
-                    if !self.index.is_spk_owned(&full_txout.txout.script_pubkey) {
-                        return false;
-                    }
-                }
-                true
-            })
-    }
-
-    pub fn list_owned_unspents<'a, C: ChainOracle<Error = Infallible> + 'a>(
-        &'a self,
-        chain: &'a C,
-        chain_tip: BlockId,
-    ) -> impl Iterator<Item = FullTxOut<ObservedAs<A>>> + 'a {
-        self.try_list_owned_unspents(chain, chain_tip)
-            .map(|r| r.expect("oracle is infallible"))
-    }
-
-    pub fn try_balance<C, F>(
-        &self,
-        chain: &C,
-        chain_tip: BlockId,
-        mut should_trust: F,
-    ) -> Result<Balance, C::Error>
-    where
-        C: ChainOracle,
-        F: FnMut(&Script) -> bool,
-    {
-        let tip_height = chain_tip.height;
-
-        let mut immature = 0;
-        let mut trusted_pending = 0;
-        let mut untrusted_pending = 0;
-        let mut confirmed = 0;
-
-        for res in self.try_list_owned_unspents(chain, chain_tip) {
-            let txout = res?;
-
-            match &txout.chain_position {
-                ObservedAs::Confirmed(_) => {
-                    if txout.is_confirmed_and_spendable(tip_height) {
-                        confirmed += txout.txout.value;
-                    } else if !txout.is_mature(tip_height) {
-                        immature += txout.txout.value;
-                    }
-                }
-                ObservedAs::Unconfirmed(_) => {
-                    if should_trust(&txout.txout.script_pubkey) {
-                        trusted_pending += txout.txout.value;
-                    } else {
-                        untrusted_pending += txout.txout.value;
-                    }
-                }
-            }
-        }
-
-        Ok(Balance {
-            immature,
-            trusted_pending,
-            untrusted_pending,
-            confirmed,
-        })
-    }
-
-    pub fn balance<C, F>(&self, chain: &C, chain_tip: BlockId, should_trust: F) -> Balance
-    where
-        C: ChainOracle<Error = Infallible>,
-        F: FnMut(&Script) -> bool,
-    {
-        self.try_balance(chain, chain_tip, should_trust)
-            .expect("error is infallible")
-    }
-}
-
 /// A structure that represents changes to an [`IndexedTxGraph`].
 #[derive(Clone, Debug, PartialEq)]
 #[cfg_attr(
@@ -324,9 +220,3 @@ pub trait Indexer {
     /// Determines whether the transaction should be included in the index.
     fn is_tx_relevant(&self, tx: &Transaction) -> bool;
 }
-
-/// A trait that extends [`Indexer`] to also index "owned" script pubkeys.
-pub trait OwnedIndexer: Indexer {
-    /// Determines whether a given script pubkey (`spk`) is owned.
-    fn is_spk_owned(&self, spk: &Script) -> bool;
-}
index c7a8dd54b42746aff83608a42307dfa0766a16ae..397c43386d2dee8150a4ceaac2e42ecf2b506117 100644 (file)
@@ -1,6 +1,6 @@
 use crate::{
     collections::*,
-    indexed_tx_graph::{Indexer, OwnedIndexer},
+    indexed_tx_graph::Indexer,
     miniscript::{Descriptor, DescriptorPublicKey},
     spk_iter::BIP32_MAX_INDEX,
     ForEachTxOut, SpkIterator, SpkTxOutIndex,
@@ -109,12 +109,6 @@ impl<K: Clone + Ord + Debug> Indexer for KeychainTxOutIndex<K> {
     }
 }
 
-impl<K: Clone + Ord + Debug> OwnedIndexer for KeychainTxOutIndex<K> {
-    fn is_spk_owned(&self, spk: &Script) -> bool {
-        self.index_of_spk(spk).is_some()
-    }
-}
-
 impl<K: Clone + Ord + Debug> KeychainTxOutIndex<K> {
     /// Scans an object for relevant outpoints, which are stored and indexed internally.
     ///
@@ -153,6 +147,11 @@ impl<K: Clone + Ord + Debug> KeychainTxOutIndex<K> {
         &self.inner
     }
 
+    /// Get a reference to the set of indexed outpoints.
+    pub fn outpoints(&self) -> &BTreeSet<((K, u32), OutPoint)> {
+        self.inner.outpoints()
+    }
+
     /// Return a reference to the internal map of the keychain to descriptors.
     pub fn keychains(&self) -> &BTreeMap<K, Descriptor<DescriptorPublicKey>> {
         &self.keychains
index ae94414921c103635ea71c4232f96f3de06efcca..0eaec4bb790b6a8e453dc5ee238942116d5c94df 100644 (file)
@@ -2,7 +2,7 @@ use core::ops::RangeBounds;
 
 use crate::{
     collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap},
-    indexed_tx_graph::{Indexer, OwnedIndexer},
+    indexed_tx_graph::Indexer,
     ForEachTxOut,
 };
 use bitcoin::{self, OutPoint, Script, Transaction, TxOut, Txid};
@@ -75,12 +75,6 @@ impl<I: Clone + Ord> Indexer for SpkTxOutIndex<I> {
     }
 }
 
-impl<I: Clone + Ord + 'static> OwnedIndexer for SpkTxOutIndex<I> {
-    fn is_spk_owned(&self, spk: &Script) -> bool {
-        self.spk_indices.get(spk).is_some()
-    }
-}
-
 /// This macro is used instead of a member function of `SpkTxOutIndex`, which would result in a
 /// compiler error[E0521]: "borrowed data escapes out of closure" when we attempt to take a
 /// reference out of the `ForEachTxOut` closure during scanning.
@@ -126,6 +120,11 @@ impl<I: Clone + Ord> SpkTxOutIndex<I> {
         scan_txout!(self, op, txout)
     }
 
+    /// Get a reference to the set of indexed outpoints.
+    pub fn outpoints(&self) -> &BTreeSet<(I, OutPoint)> {
+        &self.spk_txouts
+    }
+
     /// Iterate over all known txouts that spend to tracked script pubkeys.
     pub fn txouts(
         &self,
index ef3f3847cebed66ba9ac2401f78da5cde8066dc5..f0a095ace64ec6677d8da9f93a13981bab517148 100644 (file)
 //! ```
 
 use crate::{
-    collections::*, Anchor, Append, BlockId, ChainOracle, ForEachTxOut, FullTxOut, ObservedAs,
+    collections::*, keychain::Balance, Anchor, Append, BlockId, ChainOracle, ForEachTxOut,
+    FullTxOut, ObservedAs,
 };
 use alloc::vec::Vec;
-use bitcoin::{OutPoint, Transaction, TxOut, Txid};
+use bitcoin::{OutPoint, Script, Transaction, TxOut, Txid};
 use core::{
     convert::Infallible,
     ops::{Deref, RangeInclusive},
@@ -762,107 +763,190 @@ impl<A: Anchor> TxGraph<A> {
             .map(|r| r.expect("oracle is infallible"))
     }
 
-    /// List outputs that are in `chain` with `chain_tip`.
+    /// Get a filtered list of outputs from the given `outpoints` that are in `chain` with
+    /// `chain_tip`.
     ///
-    /// Floating ouputs are not iterated over.
+    /// `outpoints` is a list of outpoints we are interested in, coupled with the associated txout's
+    /// script pubkey index (`S`).
     ///
-    /// The `filter_predicate` should return true for outputs that we wish to iterate over.
+    /// Floating outputs are ignored.
     ///
     /// # Error
     ///
-    /// A returned item can error if the [`ChainOracle`] implementation (`chain`) fails.
+    /// An [`Iterator::Item`] can be an [`Err`] if the [`ChainOracle`] implementation (`chain`)
+    /// fails.
     ///
-    /// If the [`ChainOracle`] is infallible, [`list_chain_txouts`] can be used instead.
+    /// If the [`ChainOracle`] implementation is infallible, [`filter_chain_txouts`] can be used
+    /// instead.
     ///
-    /// [`list_chain_txouts`]: Self::list_chain_txouts
-    pub fn try_list_chain_txouts<'a, C: ChainOracle + 'a>(
+    /// [`filter_chain_txouts`]: Self::filter_chain_txouts
+    pub fn try_filter_chain_txouts<'a, C: ChainOracle + 'a, S: Clone + 'a>(
         &'a self,
         chain: &'a C,
         chain_tip: BlockId,
-    ) -> impl Iterator<Item = Result<FullTxOut<ObservedAs<A>>, C::Error>> + 'a {
-        self.try_list_chain_txs(chain, chain_tip)
-            .flat_map(move |tx_res| match tx_res {
-                Ok(canonical_tx) => canonical_tx
-                    .node
-                    .output
-                    .iter()
-                    .enumerate()
-                    .map(|(vout, txout)| {
-                        let outpoint = OutPoint::new(canonical_tx.node.txid, vout as _);
-                        Ok((outpoint, txout.clone(), canonical_tx.clone()))
-                    })
-                    .collect::<Vec<_>>(),
-                Err(err) => vec![Err(err)],
-            })
-            .map(move |res| -> Result<_, C::Error> {
-                let (
-                    outpoint,
-                    txout,
-                    CanonicalTx {
-                        observed_as,
-                        node: tx_node,
-                    },
-                ) = res?;
-                let chain_position = observed_as.cloned();
-                let spent_by = self
-                    .try_get_chain_spend(chain, chain_tip, outpoint)?
-                    .map(|(obs_as, txid)| (obs_as.cloned(), txid));
-                let is_on_coinbase = tx_node.tx.is_coin_base();
-                Ok(FullTxOut {
-                    outpoint,
-                    txout,
-                    chain_position,
-                    spent_by,
-                    is_on_coinbase,
-                })
-            })
+        outpoints: impl IntoIterator<Item = (S, OutPoint)> + 'a,
+    ) -> impl Iterator<Item = Result<(S, FullTxOut<ObservedAs<A>>), C::Error>> + 'a {
+        outpoints
+            .into_iter()
+            .map(
+                move |(spk_i, op)| -> Result<Option<(S, FullTxOut<_>)>, C::Error> {
+                    let tx_node = match self.get_tx_node(op.txid) {
+                        Some(n) => n,
+                        None => return Ok(None),
+                    };
+
+                    let txout = match tx_node.tx.output.get(op.vout as usize) {
+                        Some(txout) => txout.clone(),
+                        None => return Ok(None),
+                    };
+
+                    let chain_position =
+                        match self.try_get_chain_position(chain, chain_tip, op.txid)? {
+                            Some(pos) => pos.cloned(),
+                            None => return Ok(None),
+                        };
+
+                    let spent_by = self
+                        .try_get_chain_spend(chain, chain_tip, op)?
+                        .map(|(a, txid)| (a.cloned(), txid));
+
+                    Ok(Some((
+                        spk_i,
+                        FullTxOut {
+                            outpoint: op,
+                            txout,
+                            chain_position,
+                            spent_by,
+                            is_on_coinbase: tx_node.tx.is_coin_base(),
+                        },
+                    )))
+                },
+            )
+            .filter_map(Result::transpose)
     }
 
-    /// List outputs that are in `chain` with `chain_tip`.
+    /// Get a filtered list of outputs from the given `outpoints` that are in `chain` with
+    /// `chain_tip`.
     ///
-    /// This is the infallible version of [`try_list_chain_txouts`].
+    /// This is the infallible version of [`try_filter_chain_txouts`].
     ///
-    /// [`try_list_chain_txouts`]: Self::try_list_chain_txouts
-    pub fn list_chain_txouts<'a, C: ChainOracle<Error = Infallible> + 'a>(
+    /// [`try_filter_chain_txouts`]: Self::try_filter_chain_txouts
+    pub fn filter_chain_txouts<'a, C: ChainOracle<Error = Infallible> + 'a, S: Clone + 'a>(
         &'a self,
         chain: &'a C,
         chain_tip: BlockId,
-    ) -> impl Iterator<Item = FullTxOut<ObservedAs<A>>> + 'a {
-        self.try_list_chain_txouts(chain, chain_tip)
-            .map(|r| r.expect("error in infallible"))
+        outpoints: impl IntoIterator<Item = (S, OutPoint)> + 'a,
+    ) -> impl Iterator<Item = (S, FullTxOut<ObservedAs<A>>)> + 'a {
+        self.try_filter_chain_txouts(chain, chain_tip, outpoints)
+            .map(|r| r.expect("oracle is infallible"))
     }
 
-    /// List unspent outputs (UTXOs) that are in `chain` with `chain_tip`.
+    /// Get a filtered list of unspent outputs (UTXOs) from the given `outpoints` that are in
+    /// `chain` with `chain_tip`.
+    ///
+    /// `outpoints` is a list of outpoints we are interested in, coupled with the associated txout's
+    /// script pubkey index (`S`).
     ///
-    /// Floating outputs are not iterated over.
+    /// Floating outputs are ignored.
     ///
     /// # Error
     ///
-    /// An item can be an error if the [`ChainOracle`] implementation fails. If the oracle is
-    /// infallible, [`list_chain_unspents`] can be used instead.
+    /// An [`Iterator::Item`] can be an [`Err`] if the [`ChainOracle`] implementation (`chain`)
+    /// fails.
     ///
-    /// [`list_chain_unspents`]: Self::list_chain_unspents
-    pub fn try_list_chain_unspents<'a, C: ChainOracle + 'a>(
+    /// If the [`ChainOracle`] implementation is infallible, [`filter_chain_unspents`] can be used
+    /// instead.
+    ///
+    /// [`filter_chain_unspents`]: Self::filter_chain_unspents
+    pub fn try_filter_chain_unspents<'a, C: ChainOracle + 'a, S: Clone + 'a>(
         &'a self,
         chain: &'a C,
         chain_tip: BlockId,
-    ) -> impl Iterator<Item = Result<FullTxOut<ObservedAs<A>>, C::Error>> + 'a {
-        self.try_list_chain_txouts(chain, chain_tip)
-            .filter(|r| matches!(r, Ok(txo) if txo.spent_by.is_none()))
+        outpoints: impl IntoIterator<Item = (S, OutPoint)> + 'a,
+    ) -> impl Iterator<Item = Result<(S, FullTxOut<ObservedAs<A>>), C::Error>> + 'a {
+        self.try_filter_chain_txouts(chain, chain_tip, outpoints)
+            .filter(|r| !matches!(r, Ok((_, full_txo)) if full_txo.spent_by.is_some()))
     }
 
-    /// List unspent outputs (UTXOs) that are in `chain` with `chain_tip`.
+    /// Get a filtered list of unspent outputs (UTXOs) from the given `outpoints` that are in
+    /// `chain` with `chain_tip`.
     ///
-    /// This is the infallible version of [`try_list_chain_unspents`].
+    /// This is the infallible version of [`try_filter_chain_unspents`].
     ///
-    /// [`try_list_chain_unspents`]: Self::try_list_chain_unspents
-    pub fn list_chain_unspents<'a, C: ChainOracle<Error = Infallible> + 'a>(
+    /// [`try_filter_chain_unspents`]: Self::try_filter_chain_unspents
+    pub fn filter_chain_unspents<'a, C: ChainOracle<Error = Infallible> + 'a, S: Clone + 'a>(
         &'a self,
         chain: &'a C,
-        static_block: BlockId,
-    ) -> impl Iterator<Item = FullTxOut<ObservedAs<A>>> + 'a {
-        self.try_list_chain_unspents(chain, static_block)
-            .map(|r| r.expect("error is infallible"))
+        chain_tip: BlockId,
+        txouts: impl IntoIterator<Item = (S, OutPoint)> + 'a,
+    ) -> impl Iterator<Item = (S, FullTxOut<ObservedAs<A>>)> + 'a {
+        self.try_filter_chain_unspents(chain, chain_tip, txouts)
+            .map(|r| r.expect("oracle is infallible"))
+    }
+
+    /// Get the total balance of `outpoints` that are in `chain` of `chain_tip`.
+    ///
+    /// The output of `trust_predicate` should return `true` for scripts that we trust.
+    ///
+    /// If the provided [`ChainOracle`] implementation (`chain`) is infallible, [`balance`] can be
+    /// used instead.
+    ///
+    /// [`balance`]: Self::balance
+    pub fn try_balance<C: ChainOracle, S: Clone>(
+        &self,
+        chain: &C,
+        chain_tip: BlockId,
+        outpoints: impl IntoIterator<Item = (S, OutPoint)>,
+        mut trust_predicate: impl FnMut(&S, &Script) -> bool,
+    ) -> Result<Balance, C::Error> {
+        let mut immature = 0;
+        let mut trusted_pending = 0;
+        let mut untrusted_pending = 0;
+        let mut confirmed = 0;
+
+        for res in self.try_filter_chain_unspents(chain, chain_tip, outpoints) {
+            let (spk_i, txout) = res?;
+
+            match &txout.chain_position {
+                ObservedAs::Confirmed(_) => {
+                    if txout.is_confirmed_and_spendable(chain_tip.height) {
+                        confirmed += txout.txout.value;
+                    } else if !txout.is_mature(chain_tip.height) {
+                        immature += txout.txout.value;
+                    }
+                }
+                ObservedAs::Unconfirmed(_) => {
+                    if trust_predicate(&spk_i, &txout.txout.script_pubkey) {
+                        trusted_pending += txout.txout.value;
+                    } else {
+                        untrusted_pending += txout.txout.value;
+                    }
+                }
+            }
+        }
+
+        Ok(Balance {
+            immature,
+            trusted_pending,
+            untrusted_pending,
+            confirmed,
+        })
+    }
+
+    /// Get the total balance of `outpoints` that are in `chain` of `chain_tip`.
+    ///
+    /// This is the infallible version of [`try_balance`].
+    ///
+    /// [`try_balance`]: Self::try_balance
+    pub fn balance<C: ChainOracle<Error = Infallible>, S: Clone>(
+        &self,
+        chain: &C,
+        chain_tip: BlockId,
+        outpoints: impl IntoIterator<Item = (S, OutPoint)>,
+        trust_predicate: impl FnMut(&S, &Script) -> bool,
+    ) -> Balance {
+        self.try_balance(chain, chain_tip, outpoints, trust_predicate)
+            .expect("oracle is infallible")
     }
 }
 
index f32ffe4f0b28dbec2a5316f1460e1d062f842db9..f231f76835fd29a9e0fd4fa78027dae9077ab038 100644 (file)
@@ -236,23 +236,36 @@ fn test_list_owned_txouts() {
                 .map(|&hash| BlockId { height, hash })
                 .expect("block must exist");
             let txouts = graph
-                .list_owned_txouts(&local_chain, chain_tip)
+                .graph()
+                .filter_chain_txouts(
+                    &local_chain,
+                    chain_tip,
+                    graph.index.outpoints().iter().cloned(),
+                )
                 .collect::<Vec<_>>();
 
             let utxos = graph
-                .list_owned_unspents(&local_chain, chain_tip)
+                .graph()
+                .filter_chain_unspents(
+                    &local_chain,
+                    chain_tip,
+                    graph.index.outpoints().iter().cloned(),
+                )
                 .collect::<Vec<_>>();
 
-            let balance = graph.balance(&local_chain, chain_tip, |spk: &Script| {
-                trusted_spks.contains(spk)
-            });
+            let balance = graph.graph().balance(
+                &local_chain,
+                chain_tip,
+                graph.index.outpoints().iter().cloned(),
+                |_, spk: &Script| trusted_spks.contains(spk),
+            );
 
             assert_eq!(txouts.len(), 5);
             assert_eq!(utxos.len(), 4);
 
             let confirmed_txouts_txid = txouts
                 .iter()
-                .filter_map(|full_txout| {
+                .filter_map(|(_, full_txout)| {
                     if matches!(full_txout.chain_position, ObservedAs::Confirmed(_)) {
                         Some(full_txout.outpoint.txid)
                     } else {
@@ -263,7 +276,7 @@ fn test_list_owned_txouts() {
 
             let unconfirmed_txouts_txid = txouts
                 .iter()
-                .filter_map(|full_txout| {
+                .filter_map(|(_, full_txout)| {
                     if matches!(full_txout.chain_position, ObservedAs::Unconfirmed(_)) {
                         Some(full_txout.outpoint.txid)
                     } else {
@@ -274,7 +287,7 @@ fn test_list_owned_txouts() {
 
             let confirmed_utxos_txid = utxos
                 .iter()
-                .filter_map(|full_txout| {
+                .filter_map(|(_, full_txout)| {
                     if matches!(full_txout.chain_position, ObservedAs::Confirmed(_)) {
                         Some(full_txout.outpoint.txid)
                     } else {
@@ -285,7 +298,7 @@ fn test_list_owned_txouts() {
 
             let unconfirmed_utxos_txid = utxos
                 .iter()
-                .filter_map(|full_txout| {
+                .filter_map(|(_, full_txout)| {
                     if matches!(full_txout.chain_position, ObservedAs::Unconfirmed(_)) {
                         Some(full_txout.outpoint.txid)
                     } else {