]> Untitled Git - bdk/commitdiff
feat(chain): add `CheckPoint::from_block_ids` convenience method
author志宇 <hello@evanlinjin.me>
Wed, 10 Jan 2024 09:42:03 +0000 (17:42 +0800)
committer志宇 <hello@evanlinjin.me>
Mon, 15 Jan 2024 16:23:42 +0000 (00:23 +0800)
crates/chain/src/local_chain.rs
crates/chain/tests/test_local_chain.rs

index 32fd728522a7713cbaf8bcff407fb9b4d5d0e410..f6d8af9f6af0069de186ecba6780d0e4ed2c8fcf 100644 (file)
@@ -39,6 +39,28 @@ impl CheckPoint {
         Self(Arc::new(CPInner { block, prev: None }))
     }
 
+    /// Construct a checkpoint from a list of [`BlockId`]s in ascending height order.
+    ///
+    /// # Errors
+    ///
+    /// This method will error if any of the follow occurs:
+    ///
+    /// - The `blocks` iterator is empty, in which case, the error will be `None`.
+    /// - The `blocks` iterator is not in ascending height order.
+    /// - The `blocks` iterator contains multiple [`BlockId`]s of the same height.
+    ///
+    /// The error type is the last successful checkpoint constructed (if any).
+    pub fn from_block_ids(
+        block_ids: impl IntoIterator<Item = BlockId>,
+    ) -> Result<Self, Option<Self>> {
+        let mut blocks = block_ids.into_iter();
+        let mut acc = CheckPoint::new(blocks.next().ok_or(None)?);
+        for id in blocks {
+            acc = acc.push(id).map_err(Some)?;
+        }
+        Ok(acc)
+    }
+
     /// Construct a checkpoint from the given `header` and block `height`.
     ///
     /// If `header` is of the genesis block, the checkpoint won't have a [`prev`] node. Otherwise,
index 25cbbb08e3f2b72c1ed862c6328fb9d1c6dd3a0f..7e6f73bf2cf7f29bc6d47839e9114744b5cf496d 100644 (file)
@@ -1,5 +1,9 @@
-use bdk_chain::local_chain::{
-    AlterCheckPointError, CannotConnectError, ChangeSet, LocalChain, MissingGenesisError, Update,
+use bdk_chain::{
+    local_chain::{
+        AlterCheckPointError, CannotConnectError, ChangeSet, CheckPoint, LocalChain,
+        MissingGenesisError, Update,
+    },
+    BlockId,
 };
 use bitcoin::BlockHash;
 
@@ -423,3 +427,82 @@ fn local_chain_disconnect_from() {
         );
     }
 }
+
+#[test]
+fn checkpoint_from_block_ids() {
+    struct TestCase<'a> {
+        name: &'a str,
+        blocks: &'a [(u32, BlockHash)],
+        exp_result: Result<(), Option<(u32, BlockHash)>>,
+    }
+
+    let test_cases = [
+        TestCase {
+            name: "in_order",
+            blocks: &[(0, h!("A")), (1, h!("B")), (3, h!("D"))],
+            exp_result: Ok(()),
+        },
+        TestCase {
+            name: "with_duplicates",
+            blocks: &[(1, h!("B")), (2, h!("C")), (2, h!("C'"))],
+            exp_result: Err(Some((2, h!("C")))),
+        },
+        TestCase {
+            name: "not_in_order",
+            blocks: &[(1, h!("B")), (3, h!("D")), (2, h!("C"))],
+            exp_result: Err(Some((3, h!("D")))),
+        },
+        TestCase {
+            name: "empty",
+            blocks: &[],
+            exp_result: Err(None),
+        },
+        TestCase {
+            name: "single",
+            blocks: &[(21, h!("million"))],
+            exp_result: Ok(()),
+        },
+    ];
+
+    for (i, t) in test_cases.into_iter().enumerate() {
+        println!("running test case {}: '{}'", i, t.name);
+        let result = CheckPoint::from_block_ids(
+            t.blocks
+                .iter()
+                .map(|&(height, hash)| BlockId { height, hash }),
+        );
+        match t.exp_result {
+            Ok(_) => {
+                assert!(result.is_ok(), "[{}:{}] should be Ok", i, t.name);
+                let result_vec = {
+                    let mut v = result
+                        .unwrap()
+                        .into_iter()
+                        .map(|cp| (cp.height(), cp.hash()))
+                        .collect::<Vec<_>>();
+                    v.reverse();
+                    v
+                };
+                assert_eq!(
+                    &result_vec, t.blocks,
+                    "[{}:{}] not equal to original block ids",
+                    i, t.name
+                );
+            }
+            Err(exp_last) => {
+                assert!(result.is_err(), "[{}:{}] should be Err", i, t.name);
+                let err = result.unwrap_err();
+                assert_eq!(
+                    err.as_ref()
+                        .map(|last_cp| (last_cp.height(), last_cp.hash())),
+                    exp_last,
+                    "[{}:{}] error's last cp height should be {:?}, got {:?}",
+                    i,
+                    t.name,
+                    exp_last,
+                    err
+                );
+            }
+        }
+    }
+}