]> Untitled Git - bdk/commitdiff
Fix policy condition calculation
authorAlekos Filini <alekos.filini@gmail.com>
Fri, 31 Mar 2023 18:08:23 +0000 (20:08 +0200)
committerAlekos Filini <alekos.filini@gmail.com>
Sat, 1 Apr 2023 10:57:49 +0000 (12:57 +0200)
When constructing the `Condition` struct we recursively call
`get_condition` on all the items in a threshold and short-circuit if
there's an error somewhere (for example, because the policy-path hasn't
been provided for a specific threshold).

This can cause issues when the user doesn't care about a subtree, because
we still try to call `get_condition` on all the items and fail if
something is missing, even if the specific subtree isn't selected and
won't be used later on.

This commit changes the logic so that we first filter only the `selected`
items, and then unwrap the error using the question mark. If errors
happened somewhere else they will be ignored, as it should.

crates/bdk/src/descriptor/policy.rs
crates/bdk/tests/wallet.rs

index af3e4a3b8c7fb1877e7ee28dcbb5e4109e7a8110..5de20ae7a3304404348d131cfa0496642a22ee75 100644 (file)
@@ -662,11 +662,11 @@ impl Policy {
                 (0..*threshold).collect()
             }
             SatisfiableItem::Multisig { keys, .. } => (0..keys.len()).collect(),
-            _ => vec![],
+            _ => HashSet::new(),
         };
-        let selected = match path.get(&self.id) {
-            Some(arr) => arr,
-            _ => &default,
+        let selected: HashSet<_> = match path.get(&self.id) {
+            Some(arr) => arr.iter().copied().collect(),
+            _ => default,
         };
 
         match &self.item {
@@ -674,14 +674,24 @@ impl Policy {
                 let mapped_req = items
                     .iter()
                     .map(|i| i.get_condition(path))
-                    .collect::<Result<Vec<_>, _>>()?;
+                    .collect::<Vec<_>>();
 
                 // if all the requirements are null we don't care about `selected` because there
                 // are no requirements
-                if mapped_req.iter().all(Condition::is_null) {
+                if mapped_req
+                    .iter()
+                    .all(|cond| matches!(cond, Ok(c) if c.is_null()))
+                {
                     return Ok(Condition::default());
                 }
 
+                // make sure all the indexes in the `selected` list are within range
+                for index in &selected {
+                    if *index >= items.len() {
+                        return Err(PolicyError::IndexOutOfRange(*index));
+                    }
+                }
+
                 // if we have something, make sure we have enough items. note that the user can set
                 // an empty value for this step in case of n-of-n, because `selected` is set to all
                 // the elements above
@@ -690,23 +700,18 @@ impl Policy {
                 }
 
                 // check the selected items, see if there are conflicting requirements
-                let mut requirements = Condition::default();
-                for item_index in selected {
-                    requirements = requirements.merge(
-                        mapped_req
-                            .get(*item_index)
-                            .ok_or(PolicyError::IndexOutOfRange(*item_index))?,
-                    )?;
-                }
-
-                Ok(requirements)
+                mapped_req
+                    .into_iter()
+                    .enumerate()
+                    .filter(|(index, _)| selected.contains(index))
+                    .try_fold(Condition::default(), |acc, (_, cond)| acc.merge(&cond?))
             }
             SatisfiableItem::Multisig { keys, threshold } => {
                 if selected.len() < *threshold {
                     return Err(PolicyError::NotEnoughItemsSelected(self.id.clone()));
                 }
-                if let Some(item) = selected.iter().find(|i| **i >= keys.len()) {
-                    return Err(PolicyError::IndexOutOfRange(*item));
+                if let Some(item) = selected.into_iter().find(|&i| i >= keys.len()) {
+                    return Err(PolicyError::IndexOutOfRange(item));
                 }
 
                 Ok(Condition::default())
index 9b25223e47d553360ae590804813e7944f42d511..0ada20d398d2be5c7c9c5308b06e5c24a8fb25bb 100644 (file)
@@ -925,6 +925,25 @@ fn test_create_tx_policy_path_use_csv() {
     assert_eq!(psbt.unsigned_tx.input[0].sequence, Sequence(144));
 }
 
+#[test]
+fn test_create_tx_policy_path_ignored_subtree_with_csv() {
+    let (mut wallet, _) = get_funded_wallet("wsh(or_d(pk(cRjo6jqfVNP33HhSS76UhXETZsGTZYx8FMFvR9kpbtCSV1PmdZdu),or_i(and_v(v:pkh(cVpPVruEDdmutPzisEsYvtST1usBR3ntr8pXSyt6D2YYqXRyPcFW),older(30)),and_v(v:pkh(cMnkdebixpXMPfkcNEjjGin7s94hiehAH4mLbYkZoh9KSiNNmqC8),older(90)))))");
+
+    let external_policy = wallet.policies(KeychainKind::External).unwrap().unwrap();
+    let root_id = external_policy.id;
+    // child #0 is pk(cRjo6jqfVNP33HhSS76UhXETZsGTZYx8FMFvR9kpbtCSV1PmdZdu)
+    let path = vec![(root_id, vec![0])].into_iter().collect();
+
+    let addr = Address::from_str("2N1Ffz3WaNzbeLFBb51xyFMHYSEUXcbiSoX").unwrap();
+    let mut builder = wallet.build_tx();
+    builder
+        .add_recipient(addr.script_pubkey(), 30_000)
+        .policy_path(path, KeychainKind::External);
+    let (psbt, _) = builder.finish().unwrap();
+
+    assert_eq!(psbt.unsigned_tx.input[0].sequence, Sequence(0xFFFFFFFE));
+}
+
 #[test]
 fn test_create_tx_global_xpubs_with_origin() {
     use bitcoin::hashes::hex::FromHex;