]> Untitled Git - bdk/commitdiff
[bdk_chain_redesign] Add balance methods to `IndexedTxGraph`
author志宇 <hello@evanlinjin.me>
Mon, 27 Mar 2023 11:55:57 +0000 (19:55 +0800)
committer志宇 <hello@evanlinjin.me>
Mon, 27 Mar 2023 11:55:57 +0000 (19:55 +0800)
crates/chain/src/chain_data.rs
crates/chain/src/indexed_tx_graph.rs

index 43eb64f6e63ed4ed0f7617cb43fbe8bcc39972aa..df5a5e9c76071337d509ef1e116863a19b249706 100644 (file)
@@ -14,6 +14,15 @@ pub enum ObservedIn<A> {
     Mempool(u64),
 }
 
+impl<A: Clone> ObservedIn<&A> {
+    pub fn into_owned(self) -> ObservedIn<A> {
+        match self {
+            ObservedIn::Block(a) => ObservedIn::Block(a.clone()),
+            ObservedIn::Mempool(last_seen) => ObservedIn::Mempool(last_seen),
+        }
+    }
+}
+
 impl ChainPosition for ObservedIn<BlockId> {
     fn height(&self) -> TxHeight {
         match self {
@@ -259,4 +268,16 @@ impl<I: ChainPosition> FullTxOut<I> {
     }
 }
 
+impl<A: Clone> FullTxOut<ObservedIn<&A>> {
+    pub fn into_owned(self) -> FullTxOut<ObservedIn<A>> {
+        FullTxOut {
+            outpoint: self.outpoint,
+            txout: self.txout,
+            chain_position: self.chain_position.into_owned(),
+            spent_by: self.spent_by.map(|(o, txid)| (o.into_owned(), txid)),
+            is_on_coinbase: self.is_on_coinbase,
+        }
+    }
+}
+
 // TODO: make test
index 5071fb2c76afbec582b9fe66e985068522250858..5361437e1471f01c4995d036caae33c6697bea98 100644 (file)
@@ -4,6 +4,7 @@ use alloc::collections::BTreeSet;
 use bitcoin::{OutPoint, Transaction, TxOut};
 
 use crate::{
+    keychain::Balance,
     sparse_chain::ChainPosition,
     tx_graph::{Additions, TxGraph, TxInGraph},
     BlockAnchor, ChainOracle, FullTxOut, ObservedIn, TxIndex, TxIndexAdditions,
@@ -260,4 +261,86 @@ impl<A: BlockAnchor, I: TxIndex> IndexedTxGraph<A, I> {
         self.try_list_chain_utxos(chain)
             .map(|r| r.expect("error is infallible"))
     }
+
+    pub fn try_balance<C, F>(
+        &self,
+        chain: C,
+        tip: u32,
+        mut should_trust: F,
+    ) -> Result<Balance, C::Error>
+    where
+        C: ChainOracle,
+        ObservedIn<A>: ChainPosition + Clone,
+        F: FnMut(&I::SpkIndex) -> bool,
+    {
+        let mut immature = 0;
+        let mut trusted_pending = 0;
+        let mut untrusted_pending = 0;
+        let mut confirmed = 0;
+
+        for res in self.try_list_chain_txouts(&chain) {
+            let TxOutInChain { spk_index, txout } = res?;
+            let txout = txout.into_owned();
+
+            match &txout.chain_position {
+                ObservedIn::Block(_) => {
+                    if txout.is_on_coinbase {
+                        if txout.is_mature(tip) {
+                            confirmed += txout.txout.value;
+                        } else {
+                            immature += txout.txout.value;
+                        }
+                    }
+                }
+                ObservedIn::Mempool(_) => {
+                    if should_trust(spk_index) {
+                        trusted_pending += txout.txout.value;
+                    } else {
+                        untrusted_pending += txout.txout.value;
+                    }
+                }
+            }
+        }
+
+        Ok(Balance {
+            immature,
+            trusted_pending,
+            untrusted_pending,
+            confirmed,
+        })
+    }
+
+    pub fn balance<C, F>(&self, chain: C, tip: u32, should_trust: F) -> Balance
+    where
+        C: ChainOracle<Error = Infallible>,
+        ObservedIn<A>: ChainPosition + Clone,
+        F: FnMut(&I::SpkIndex) -> bool,
+    {
+        self.try_balance(chain, tip, should_trust)
+            .expect("error is infallible")
+    }
+
+    pub fn try_balance_at<C>(&self, chain: C, height: u32) -> Result<u64, C::Error>
+    where
+        C: ChainOracle,
+        ObservedIn<A>: ChainPosition + Clone,
+    {
+        let mut sum = 0;
+        for res in self.try_list_chain_txouts(chain) {
+            let txo = res?.txout.into_owned();
+            if txo.is_spendable_at(height) {
+                sum += txo.txout.value;
+            }
+        }
+        Ok(sum)
+    }
+
+    pub fn balance_at<C>(&self, chain: C, height: u32) -> u64
+    where
+        C: ChainOracle<Error = Infallible>,
+        ObservedIn<A>: ChainPosition + Clone,
+    {
+        self.try_balance_at(chain, height)
+            .expect("error is infallible")
+    }
 }