]> Untitled Git - bdk/commitdiff
feat(chain): add `get` and `range` methods to `CheckPoint`
author志宇 <hello@evanlinjin.me>
Wed, 6 Mar 2024 05:04:12 +0000 (13:04 +0800)
committer志宇 <hello@evanlinjin.me>
Fri, 5 Apr 2024 08:36:00 +0000 (16:36 +0800)
These methods allow us to query for checkpoints contained within the
linked list by height and height range. This is useful to determine
checkpoints to fetch for chain sources without having to refer back to
the `LocalChain`.

Currently this is not implemented efficiently, but in the future, we
will change `CheckPoint` to use a skip list structure.

crates/bdk/src/wallet/mod.rs
crates/bitcoind_rpc/tests/test_emitter.rs
crates/chain/src/local_chain.rs
crates/chain/src/tx_graph.rs
crates/chain/tests/test_indexed_tx_graph.rs
crates/chain/tests/test_local_chain.rs
crates/esplora/tests/blocking_ext.rs

index d7b825fe1c2058e4b3621d966569dde298613471..db252d3ae91595d19df620178b22e38fb3be2386 100644 (file)
@@ -1127,18 +1127,14 @@ impl<D> Wallet<D> {
                 // anchor tx to checkpoint with lowest height that is >= position's height
                 let anchor = self
                     .chain
-                    .blocks()
                     .range(height..)
-                    .next()
+                    .last()
                     .ok_or(InsertTxError::ConfirmationHeightCannotBeGreaterThanTip {
                         tip_height: self.chain.tip().height(),
                         tx_height: height,
                     })
-                    .map(|(&anchor_height, &hash)| ConfirmationTimeHeightAnchor {
-                        anchor_block: BlockId {
-                            height: anchor_height,
-                            hash,
-                        },
+                    .map(|anchor_cp| ConfirmationTimeHeightAnchor {
+                        anchor_block: anchor_cp.block_id(),
                         confirmation_height: height,
                         confirmation_time: time,
                     })?;
index 2161db0df4c2c5863265b8ebc248a025df7b4dee..97946da99def67a2d17e375c256fe8d7450003bd 100644 (file)
@@ -57,12 +57,15 @@ pub fn test_sync_local_chain() -> anyhow::Result<()> {
     }
 
     assert_eq!(
-        local_chain.blocks(),
-        &exp_hashes
+        local_chain
+            .iter_checkpoints()
+            .map(|cp| (cp.height(), cp.hash()))
+            .collect::<BTreeSet<_>>(),
+        exp_hashes
             .iter()
             .enumerate()
             .map(|(i, hash)| (i as u32, *hash))
-            .collect(),
+            .collect::<BTreeSet<_>>(),
         "final local_chain state is unexpected",
     );
 
@@ -110,12 +113,15 @@ pub fn test_sync_local_chain() -> anyhow::Result<()> {
     }
 
     assert_eq!(
-        local_chain.blocks(),
-        &exp_hashes
+        local_chain
+            .iter_checkpoints()
+            .map(|cp| (cp.height(), cp.hash()))
+            .collect::<BTreeSet<_>>(),
+        exp_hashes
             .iter()
             .enumerate()
             .map(|(i, hash)| (i as u32, *hash))
-            .collect(),
+            .collect::<BTreeSet<_>>(),
         "final local_chain state is unexpected after reorg",
     );
 
index 9be62dee386fd94342c27c176dbe691884664df9..acb5b5138a9a555d9c146975846a1ec8d934166f 100644 (file)
@@ -1,10 +1,12 @@
 //! The [`LocalChain`] is a local implementation of [`ChainOracle`].
 
 use core::convert::Infallible;
+use core::ops::RangeBounds;
 
 use crate::collections::BTreeMap;
 use crate::{BlockId, ChainOracle};
 use alloc::sync::Arc;
+use alloc::vec::Vec;
 use bitcoin::block::Header;
 use bitcoin::BlockHash;
 
@@ -148,6 +150,36 @@ impl CheckPoint {
     pub fn iter(&self) -> CheckPointIter {
         self.clone().into_iter()
     }
+
+    /// Get checkpoint at `height`.
+    ///
+    /// Returns `None` if checkpoint at `height` does not exist`.
+    pub fn get(&self, height: u32) -> Option<Self> {
+        self.range(height..=height).next()
+    }
+
+    /// Iterate checkpoints over a height range.
+    ///
+    /// Note that we always iterate checkpoints in reverse height order (iteration starts at tip
+    /// height).
+    pub fn range<R>(&self, range: R) -> impl Iterator<Item = CheckPoint>
+    where
+        R: RangeBounds<u32>,
+    {
+        let start_bound = range.start_bound().cloned();
+        let end_bound = range.end_bound().cloned();
+        self.iter()
+            .skip_while(move |cp| match end_bound {
+                core::ops::Bound::Included(inc_bound) => cp.height() > inc_bound,
+                core::ops::Bound::Excluded(exc_bound) => cp.height() >= exc_bound,
+                core::ops::Bound::Unbounded => false,
+            })
+            .take_while(move |cp| match start_bound {
+                core::ops::Bound::Included(inc_bound) => cp.height() >= inc_bound,
+                core::ops::Bound::Excluded(exc_bound) => cp.height() > exc_bound,
+                core::ops::Bound::Unbounded => true,
+            })
+    }
 }
 
 /// Iterates over checkpoints backwards.
@@ -205,18 +237,28 @@ pub struct Update {
 #[derive(Debug, Clone)]
 pub struct LocalChain {
     tip: CheckPoint,
-    index: BTreeMap<u32, BlockHash>,
 }
 
 impl PartialEq for LocalChain {
     fn eq(&self, other: &Self) -> bool {
-        self.index == other.index
+        self.iter_checkpoints()
+            .map(|cp| cp.block_id())
+            .collect::<Vec<_>>()
+            == other
+                .iter_checkpoints()
+                .map(|cp| cp.block_id())
+                .collect::<Vec<_>>()
     }
 }
 
+// TODO: Figure out whether we can get rid of this
 impl From<LocalChain> for BTreeMap<u32, BlockHash> {
     fn from(value: LocalChain) -> Self {
-        value.index
+        value
+            .tip
+            .iter()
+            .map(|cp| (cp.height(), cp.hash()))
+            .collect()
     }
 }
 
@@ -228,18 +270,16 @@ impl ChainOracle for LocalChain {
         block: BlockId,
         chain_tip: BlockId,
     ) -> Result<Option<bool>, Self::Error> {
-        if block.height > chain_tip.height {
-            return Ok(None);
+        let chain_tip_cp = match self.tip.get(chain_tip.height) {
+            // we can only determine whether `block` is in chain of `chain_tip` if `chain_tip` can
+            // be identified in chain
+            Some(cp) if cp.hash() == chain_tip.hash => cp,
+            _ => return Ok(None),
+        };
+        match chain_tip_cp.get(block.height) {
+            Some(cp) => Ok(Some(cp.hash() == block.hash)),
+            None => Ok(None),
         }
-        Ok(
-            match (
-                self.index.get(&block.height),
-                self.index.get(&chain_tip.height),
-            ) {
-                (Some(cp), Some(tip_cp)) => Some(*cp == block.hash && *tip_cp == chain_tip.hash),
-                _ => None,
-            },
-        )
     }
 
     fn get_chain_tip(&self) -> Result<BlockId, Self::Error> {
@@ -250,7 +290,7 @@ impl ChainOracle for LocalChain {
 impl LocalChain {
     /// Get the genesis hash.
     pub fn genesis_hash(&self) -> BlockHash {
-        self.index.get(&0).copied().expect("must have genesis hash")
+        self.tip.get(0).expect("genesis must exist").hash()
     }
 
     /// Construct [`LocalChain`] from genesis `hash`.
@@ -259,7 +299,6 @@ impl LocalChain {
         let height = 0;
         let chain = Self {
             tip: CheckPoint::new(BlockId { height, hash }),
-            index: core::iter::once((height, hash)).collect(),
         };
         let changeset = chain.initial_changeset();
         (chain, changeset)
@@ -276,7 +315,6 @@ impl LocalChain {
         let (mut chain, _) = Self::from_genesis_hash(genesis_hash);
         chain.apply_changeset(&changeset)?;
 
-        debug_assert!(chain._check_index_is_consistent_with_tip());
         debug_assert!(chain._check_changeset_is_applied(&changeset));
 
         Ok(chain)
@@ -284,18 +322,11 @@ impl LocalChain {
 
     /// Construct a [`LocalChain`] from a given `checkpoint` tip.
     pub fn from_tip(tip: CheckPoint) -> Result<Self, MissingGenesisError> {
-        let mut chain = Self {
-            tip,
-            index: BTreeMap::new(),
-        };
-        chain.reindex(0);
-
-        if chain.index.get(&0).copied().is_none() {
+        let genesis_cp = tip.iter().last().expect("must have at least one element");
+        if genesis_cp.height() != 0 {
             return Err(MissingGenesisError);
         }
-
-        debug_assert!(chain._check_index_is_consistent_with_tip());
-        Ok(chain)
+        Ok(Self { tip })
     }
 
     /// Constructs a [`LocalChain`] from a [`BTreeMap`] of height to [`BlockHash`].
@@ -303,12 +334,11 @@ impl LocalChain {
     /// The [`BTreeMap`] enforces the height order. However, the caller must ensure the blocks are
     /// all of the same chain.
     pub fn from_blocks(blocks: BTreeMap<u32, BlockHash>) -> Result<Self, MissingGenesisError> {
-        if !blocks.contains_key(&0) {
+        if blocks.get(&0).is_none() {
             return Err(MissingGenesisError);
         }
 
         let mut tip: Option<CheckPoint> = None;
-
         for block in &blocks {
             match tip {
                 Some(curr) => {
@@ -321,13 +351,9 @@ impl LocalChain {
             }
         }
 
-        let chain = Self {
-            index: blocks,
+        Ok(Self {
             tip: tip.expect("already checked to have genesis"),
-        };
-
-        debug_assert!(chain._check_index_is_consistent_with_tip());
-        Ok(chain)
+        })
     }
 
     /// Get the highest checkpoint.
@@ -494,9 +520,7 @@ impl LocalChain {
                 None => LocalChain::from_blocks(extension)?.tip(),
             };
             self.tip = new_tip;
-            self.reindex(start_height);
 
-            debug_assert!(self._check_index_is_consistent_with_tip());
             debug_assert!(self._check_changeset_is_applied(changeset));
         }
 
@@ -509,16 +533,16 @@ impl LocalChain {
     ///
     /// Replacing the block hash of an existing checkpoint will result in an error.
     pub fn insert_block(&mut self, block_id: BlockId) -> Result<ChangeSet, AlterCheckPointError> {
-        if let Some(&original_hash) = self.index.get(&block_id.height) {
+        if let Some(original_cp) = self.tip.get(block_id.height) {
+            let original_hash = original_cp.hash();
             if original_hash != block_id.hash {
                 return Err(AlterCheckPointError {
                     height: block_id.height,
                     original_hash,
                     update_hash: Some(block_id.hash),
                 });
-            } else {
-                return Ok(ChangeSet::default());
             }
+            return Ok(ChangeSet::default());
         }
 
         let mut changeset = ChangeSet::default();
@@ -542,33 +566,41 @@ impl LocalChain {
     /// This will fail with [`MissingGenesisError`] if the caller attempts to disconnect from the
     /// genesis block.
     pub fn disconnect_from(&mut self, block_id: BlockId) -> Result<ChangeSet, MissingGenesisError> {
-        if self.index.get(&block_id.height) != Some(&block_id.hash) {
-            return Ok(ChangeSet::default());
-        }
-
-        let changeset = self
-            .index
-            .range(block_id.height..)
-            .map(|(&height, _)| (height, None))
-            .collect::<ChangeSet>();
-        self.apply_changeset(&changeset).map(|_| changeset)
-    }
-
-    /// Reindex the heights in the chain from (and including) `from` height
-    fn reindex(&mut self, from: u32) {
-        let _ = self.index.split_off(&from);
-        for cp in self.iter_checkpoints() {
-            if cp.height() < from {
+        let mut remove_from = Option::<CheckPoint>::None;
+        let mut changeset = ChangeSet::default();
+        for cp in self.tip().iter() {
+            let cp_id = cp.block_id();
+            if cp_id.height < block_id.height {
                 break;
             }
-            self.index.insert(cp.height(), cp.hash());
+            changeset.insert(cp_id.height, None);
+            if cp_id == block_id {
+                remove_from = Some(cp);
+            }
         }
+        self.tip = match remove_from.map(|cp| cp.prev()) {
+            // The checkpoint below the earliest checkpoint to remove will be the new tip.
+            Some(Some(new_tip)) => new_tip,
+            // If there is no checkpoint below the earliest checkpoint to remove, it means the
+            // "earliest checkpoint to remove" is the genesis block. We disallow removing the
+            // genesis block.
+            Some(None) => return Err(MissingGenesisError),
+            // If there is nothing to remove, we return an empty changeset.
+            None => return Ok(ChangeSet::default()),
+        };
+        Ok(changeset)
     }
 
     /// Derives an initial [`ChangeSet`], meaning that it can be applied to an empty chain to
     /// recover the current chain.
     pub fn initial_changeset(&self) -> ChangeSet {
-        self.index.iter().map(|(k, v)| (*k, Some(*v))).collect()
+        self.tip
+            .iter()
+            .map(|cp| {
+                let block_id = cp.block_id();
+                (block_id.height, Some(block_id.hash))
+            })
+            .collect()
     }
 
     /// Iterate over checkpoints in descending height order.
@@ -578,28 +610,49 @@ impl LocalChain {
         }
     }
 
-    /// Get a reference to the internal index mapping the height to block hash.
-    pub fn blocks(&self) -> &BTreeMap<u32, BlockHash> {
-        &self.index
-    }
-
-    fn _check_index_is_consistent_with_tip(&self) -> bool {
-        let tip_history = self
-            .tip
-            .iter()
-            .map(|cp| (cp.height(), cp.hash()))
-            .collect::<BTreeMap<_, _>>();
-        self.index == tip_history
-    }
-
     fn _check_changeset_is_applied(&self, changeset: &ChangeSet) -> bool {
-        for (height, exp_hash) in changeset {
-            if self.index.get(height) != exp_hash.as_ref() {
-                return false;
+        let mut curr_cp = self.tip.clone();
+        for (height, exp_hash) in changeset.iter().rev() {
+            match curr_cp.get(*height) {
+                Some(query_cp) => {
+                    if query_cp.height() != *height || Some(query_cp.hash()) != *exp_hash {
+                        return false;
+                    }
+                    curr_cp = query_cp;
+                }
+                None => {
+                    if exp_hash.is_some() {
+                        return false;
+                    }
+                }
             }
         }
         true
     }
+
+    /// Get checkpoint at given `height` (if it exists).
+    ///
+    /// This is a shorthand for calling [`CheckPoint::get`] on the [`tip`].
+    ///
+    /// [`tip`]: LocalChain::tip
+    pub fn get(&self, height: u32) -> Option<CheckPoint> {
+        self.tip.get(height)
+    }
+
+    /// Iterate checkpoints over a height range.
+    ///
+    /// Note that we always iterate checkpoints in reverse height order (iteration starts at tip
+    /// height).
+    ///
+    /// This is a shorthand for calling [`CheckPoint::range`] on the [`tip`].
+    ///
+    /// [`tip`]: LocalChain::tip
+    pub fn range<R>(&self, range: R) -> impl Iterator<Item = CheckPoint>
+    where
+        R: RangeBounds<u32>,
+    {
+        self.tip.range(range)
+    }
 }
 
 /// An error which occurs when a [`LocalChain`] is constructed without a genesis checkpoint.
index 30d020ecb3f4e65f4a6182bcbb54b3cf085c5b7b..d951d2d31b13b4a6d0ff70c40fd372cf69118b16 100644 (file)
@@ -725,13 +725,13 @@ impl<A: Anchor> TxGraph<A> {
                     };
                     let mut has_missing_height = false;
                     for anchor_block in tx_anchors.iter().map(Anchor::anchor_block) {
-                        match chain.blocks().get(&anchor_block.height) {
+                        match chain.get(anchor_block.height) {
                             None => {
                                 has_missing_height = true;
                                 continue;
                             }
-                            Some(chain_hash) => {
-                                if chain_hash == &anchor_block.hash {
+                            Some(chain_cp) => {
+                                if chain_cp.hash() == anchor_block.hash {
                                     return true;
                                 }
                             }
@@ -749,7 +749,7 @@ impl<A: Anchor> TxGraph<A> {
             .filter_map(move |(a, _)| {
                 let anchor_block = a.anchor_block();
                 if Some(anchor_block.height) != last_height_emitted
-                    && !chain.blocks().contains_key(&anchor_block.height)
+                    && chain.get(anchor_block.height).is_none()
                 {
                     last_height_emitted = Some(anchor_block.height);
                     Some(anchor_block.height)
@@ -1299,7 +1299,7 @@ impl<A> ChangeSet<A> {
         A: Anchor,
     {
         self.anchor_heights()
-            .filter(move |height| !local_chain.blocks().contains_key(height))
+            .filter(move |&height| local_chain.get(height).is_none())
     }
 }
 
index 3fcaf2d192d16b262db7f48df701c313b5b59992..8a56db175b409d80b1e99d99ef814fb811c06da4 100644 (file)
@@ -7,7 +7,7 @@ use bdk_chain::{
     indexed_tx_graph::{self, IndexedTxGraph},
     keychain::{self, Balance, KeychainTxOutIndex},
     local_chain::LocalChain,
-    tx_graph, BlockId, ChainPosition, ConfirmationHeightAnchor,
+    tx_graph, ChainPosition, ConfirmationHeightAnchor,
 };
 use bitcoin::{secp256k1::Secp256k1, OutPoint, Script, ScriptBuf, Transaction, TxIn, TxOut};
 use miniscript::Descriptor;
@@ -212,10 +212,8 @@ fn test_list_owned_txouts() {
             (
                 *tx,
                 local_chain
-                    .blocks()
-                    .get(&height)
-                    .cloned()
-                    .map(|hash| BlockId { height, hash })
+                    .get(height)
+                    .map(|cp| cp.block_id())
                     .map(|anchor_block| ConfirmationHeightAnchor {
                         anchor_block,
                         confirmation_height: anchor_block.height,
@@ -230,9 +228,8 @@ fn test_list_owned_txouts() {
         |height: u32,
          graph: &IndexedTxGraph<ConfirmationHeightAnchor, KeychainTxOutIndex<String>>| {
             let chain_tip = local_chain
-                .blocks()
-                .get(&height)
-                .map(|&hash| BlockId { height, hash })
+                .get(height)
+                .map(|cp| cp.block_id())
                 .unwrap_or_else(|| panic!("block must exist at {}", height));
             let txouts = graph
                 .graph()
index c1a1cd7f9bf7ae2413cb645747404a0ffc58bac8..482792f5012ac2778a97e9877ab22322c7a20e6e 100644 (file)
@@ -528,6 +528,52 @@ fn checkpoint_from_block_ids() {
     }
 }
 
+#[test]
+fn checkpoint_query() {
+    struct TestCase {
+        chain: LocalChain,
+        /// The heights we want to call [`CheckPoint::query`] with, represented as an inclusive
+        /// range.
+        ///
+        /// If a [`CheckPoint`] exists at that height, we expect [`CheckPoint::query`] to return
+        /// it. If not, [`CheckPoint::query`] should return `None`.
+        query_range: (u32, u32),
+    }
+
+    let test_cases = [
+        TestCase {
+            chain: local_chain![(0, h!("_")), (1, h!("A"))],
+            query_range: (0, 2),
+        },
+        TestCase {
+            chain: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C"))],
+            query_range: (0, 3),
+        },
+    ];
+
+    for t in test_cases.into_iter() {
+        let tip = t.chain.tip();
+        for h in t.query_range.0..=t.query_range.1 {
+            let query_result = tip.get(h);
+
+            // perform an exhausitive search for the checkpoint at height `h`
+            let exp_hash = t
+                .chain
+                .iter_checkpoints()
+                .find(|cp| cp.height() == h)
+                .map(|cp| cp.hash());
+
+            match query_result {
+                Some(cp) => {
+                    assert_eq!(Some(cp.hash()), exp_hash);
+                    assert_eq!(cp.height(), h);
+                }
+                None => assert!(exp_hash.is_none()),
+            }
+        }
+    }
+}
+
 #[test]
 fn local_chain_apply_header_connected_to() {
     fn header_from_prev_blockhash(prev_blockhash: BlockHash) -> Header {
index 9e39a93c966ca7dd06b45a9cd3ea1c990afc0e3c..de0594eec4dcfde6bf5c1af48f6b3c4c1f54cce4 100644 (file)
@@ -360,8 +360,8 @@ fn update_local_chain() -> anyhow::Result<()> {
         for height in t.request_heights {
             let exp_blockhash = blocks.get(height).expect("block must exist in bitcoind");
             assert_eq!(
-                chain.blocks().get(height),
-                Some(exp_blockhash),
+                chain.get(*height).map(|cp| cp.hash()),
+                Some(*exp_blockhash),
                 "[{}:{}] block {}:{} must exist in final chain",
                 i,
                 t.name,