]> Untitled Git - bdk/commitdiff
feat(core): Optimize `CheckPointIter` using pskip
author志宇 <hello@evanlinjin.me>
Fri, 15 May 2026 18:22:26 +0000 (18:22 +0000)
committer志宇 <hello@evanlinjin.me>
Fri, 15 May 2026 18:25:57 +0000 (18:25 +0000)
crates/core/src/checkpoint.rs
crates/core/tests/test_checkpoint_skiplist.rs

index e560fd6bb17a6b2c572e117f1e50750e87377905..a343e0dc915da9d39223629f42e2354a7ca18fe3 100644 (file)
@@ -569,8 +569,46 @@ impl<D> Iterator for CheckPointIter<D> {
         self.next.clone_from(&current.prev);
         Some(CheckPoint(current))
     }
+
+    fn nth(&mut self, n: usize) -> Option<Self::Item> {
+        // Take `self.next` since if the `n`th is not found, `.next` should return `None`.
+        let current = self.next.take()?;
+
+        let target_index = current.index.checked_sub(n.try_into().ok()?)?;
+        let inner = ancestor_by_index(&current, target_index);
+
+        // Advance `self.next`.
+        self.next.clone_from(&inner.prev);
+
+        Some(CheckPoint(inner))
+    }
+
+    fn last(self) -> Option<Self::Item>
+    where
+        Self: Sized,
+    {
+        Some(CheckPoint(ancestor_by_index(&self.next?, 0)))
+    }
+
+    fn count(self) -> usize
+    where
+        Self: Sized,
+    {
+        self.next
+            .map_or(0, |cp_inner| (cp_inner.index as usize).saturating_add(1))
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let n = self
+            .next
+            .as_ref()
+            .map_or(0, |cp_inner| (cp_inner.index as usize).saturating_add(1));
+        (n, Some(n))
+    }
 }
 
+impl<D> ExactSizeIterator for CheckPointIter<D> {}
+
 impl<D> IntoIterator for CheckPoint<D> {
     type Item = CheckPoint<D>;
     type IntoIter = CheckPointIter<D>;
index 56e61a3e55dcad856ee8cd08f2fe61950809ca61..e1a082c33c85cfc6eaaf6ecc86d71ea50e290313 100644 (file)
@@ -137,6 +137,35 @@ fn test_range_edge_cases() {
     assert_eq!(from_genesis[2].height(), 0);
 }
 
+#[test]
+fn test_iter_overrides() {
+    const N: u32 = 200;
+    let mut cp = CheckPoint::new(0, BlockHash::all_zeros());
+    for height in 1..=N {
+        let hash = BlockHash::from_byte_array([(height % 256) as u8; 32]);
+        cp = cp.push(height, hash).unwrap();
+    }
+    let len = (N + 1) as usize;
+
+    // Fresh iterator: count, last, size_hint, len.
+    // (nth correctness is covered by the `iter_nth_matches_collected` proptest).
+    assert_eq!(cp.iter().count(), len);
+    assert_eq!(cp.iter().last().unwrap().height(), 0);
+    assert_eq!(cp.iter().size_hint(), (len, Some(len)));
+    assert_eq!(cp.iter().len(), len);
+
+    // After nth(k), the iterator must resume from element k+1 and len shrinks accordingly.
+    let mut it = cp.iter();
+    assert_eq!(it.nth(3).unwrap().height(), N - 3);
+    assert_eq!(it.next().unwrap().height(), N - 4);
+    assert_eq!(it.len(), len - 5);
+
+    // Out-of-range nth drains the iterator.
+    let mut it = cp.iter();
+    assert!(it.nth(usize::MAX).is_none());
+    assert_eq!(it.len(), 0);
+}
+
 /// Build a sparse chain at the given heights (genesis at 0 is implicit; `heights` must be a
 /// strictly increasing sequence of positive heights).
 fn build_chain(heights: &[u32]) -> CheckPoint<BlockHash> {
@@ -204,6 +233,19 @@ proptest! {
         ..ProptestConfig::default()
     })]
 
+    /// `iter().nth(n)` matches indexing into the fully-collected chain for arbitrary n.
+    #[test]
+    fn iter_nth_matches_collected(
+        heights in arbitrary_sparse_heights(),
+        n in 0usize..=300,
+    ) {
+        let cp = build_chain(&heights);
+        let collected: Vec<u32> = cp.iter().map(|c| c.height()).collect();
+        let got = cp.iter().nth(n).map(|c| c.height());
+        let expected = collected.get(n).copied();
+        prop_assert_eq!(got, expected);
+    }
+
     /// `get(h)` matches a linear scan for any chain and any query height (existing, missing,
     /// genesis, beyond tip, etc.).
     #[test]