]> Untitled Git - bdk/commitdiff
feat(chain): add `LocalChain::disconnect_from` method
author志宇 <hello@evanlinjin.me>
Mon, 15 Jan 2024 09:48:36 +0000 (17:48 +0800)
committer志宇 <hello@evanlinjin.me>
Mon, 15 Jan 2024 09:48:36 +0000 (17:48 +0800)
crates/chain/src/local_chain.rs
crates/chain/tests/test_local_chain.rs

index bdd25d8e0c063c95258b1fe6f03791437803db13..32fd728522a7713cbaf8bcff407fb9b4d5d0e410 100644 (file)
@@ -420,6 +420,28 @@ impl LocalChain {
         Ok(changeset)
     }
 
+    /// Removes blocks from (and inclusive of) the given `block_id`.
+    ///
+    /// This will remove blocks with a height equal or greater than `block_id`, but only if
+    /// `block_id` exists in the chain.
+    ///
+    /// # Errors
+    ///
+    /// This will fail with [`MissingGenesisError`] if the caller attempts to disconnect from the
+    /// genesis block.
+    pub fn disconnect_from(&mut self, block_id: BlockId) -> Result<ChangeSet, MissingGenesisError> {
+        if self.index.get(&block_id.height) != Some(&block_id.hash) {
+            return Ok(ChangeSet::default());
+        }
+
+        let changeset = self
+            .index
+            .range(block_id.height..)
+            .map(|(&height, _)| (height, None))
+            .collect::<ChangeSet>();
+        self.apply_changeset(&changeset).map(|_| changeset)
+    }
+
     /// Reindex the heights in the chain from (and including) `from` height
     fn reindex(&mut self, from: u32) {
         let _ = self.index.split_off(&from);
index d09325bd98e4c063c75e9dbc2f1afaae4c243a6e..25cbbb08e3f2b72c1ed862c6328fb9d1c6dd3a0f 100644 (file)
@@ -1,5 +1,5 @@
 use bdk_chain::local_chain::{
-    AlterCheckPointError, CannotConnectError, ChangeSet, LocalChain, Update,
+    AlterCheckPointError, CannotConnectError, ChangeSet, LocalChain, MissingGenesisError, Update,
 };
 use bitcoin::BlockHash;
 
@@ -350,3 +350,76 @@ fn local_chain_insert_block() {
         assert_eq!(chain, t.expected_final, "[{}] unexpected final chain", i,);
     }
 }
+
+#[test]
+fn local_chain_disconnect_from() {
+    struct TestCase {
+        name: &'static str,
+        original: LocalChain,
+        disconnect_from: (u32, BlockHash),
+        exp_result: Result<ChangeSet, MissingGenesisError>,
+        exp_final: LocalChain,
+    }
+
+    let test_cases = [
+        TestCase {
+            name: "try_replace_genesis_should_fail",
+            original: local_chain![(0, h!("_"))],
+            disconnect_from: (0, h!("_")),
+            exp_result: Err(MissingGenesisError),
+            exp_final: local_chain![(0, h!("_"))],
+        },
+        TestCase {
+            name: "try_replace_genesis_should_fail_2",
+            original: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C"))],
+            disconnect_from: (0, h!("_")),
+            exp_result: Err(MissingGenesisError),
+            exp_final: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C"))],
+        },
+        TestCase {
+            name: "from_does_not_exist",
+            original: local_chain![(0, h!("_")), (3, h!("C"))],
+            disconnect_from: (2, h!("B")),
+            exp_result: Ok(ChangeSet::default()),
+            exp_final: local_chain![(0, h!("_")), (3, h!("C"))],
+        },
+        TestCase {
+            name: "from_has_different_blockhash",
+            original: local_chain![(0, h!("_")), (2, h!("B"))],
+            disconnect_from: (2, h!("not_B")),
+            exp_result: Ok(ChangeSet::default()),
+            exp_final: local_chain![(0, h!("_")), (2, h!("B"))],
+        },
+        TestCase {
+            name: "disconnect_one",
+            original: local_chain![(0, h!("_")), (2, h!("B"))],
+            disconnect_from: (2, h!("B")),
+            exp_result: Ok(ChangeSet::from_iter([(2, None)])),
+            exp_final: local_chain![(0, h!("_"))],
+        },
+        TestCase {
+            name: "disconnect_three",
+            original: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C")), (4, h!("D"))],
+            disconnect_from: (2, h!("B")),
+            exp_result: Ok(ChangeSet::from_iter([(2, None), (3, None), (4, None)])),
+            exp_final: local_chain![(0, h!("_"))],
+        },
+    ];
+
+    for (i, t) in test_cases.into_iter().enumerate() {
+        println!("Case {}: {}", i, t.name);
+
+        let mut chain = t.original;
+        let result = chain.disconnect_from(t.disconnect_from.into());
+        assert_eq!(
+            result, t.exp_result,
+            "[{}:{}] unexpected changeset result",
+            i, t.name
+        );
+        assert_eq!(
+            chain, t.exp_final,
+            "[{}:{}] unexpected final chain",
+            i, t.name
+        );
+    }
+}