]> Untitled Git - bdk/commitdiff
Contribution improvements
authorAlekos Filini <alekos.filini@gmail.com>
Mon, 17 Feb 2020 13:22:53 +0000 (14:22 +0100)
committerAlekos Filini <alekos.filini@gmail.com>
Tue, 7 Apr 2020 09:19:38 +0000 (11:19 +0200)
Cargo.toml
examples/parse_descriptor.rs
examples/repl.rs
src/descriptor/error.rs
src/descriptor/extended_key.rs
src/descriptor/mod.rs
src/descriptor/policy.rs
src/error.rs
src/wallet/mod.rs

index f65ba76fc2164aaad0c0b96923dbcc278bedbf71..6cf3607a0e72a93836c0dc3b89e911d58a21c6e1 100644 (file)
@@ -13,7 +13,7 @@ base64 = "^0.11"
 
 # Optional dependencies
 sled = { version = "0.31.0", optional = true }
-electrum-client = { version = "0.1.0-beta.1", optional = true }
+electrum-client = { version = "0.1.0-beta.5", optional = true }
 
 [features]
 minimal = []
index 63c16b4c466b6eeae169db220ff03a3d49cb2ec4..615ba9c06d2e7a3eb4457c5c68960a2cfc1d7624 100644 (file)
@@ -17,13 +17,12 @@ fn main() {
     let extended_desc = ExtendedDescriptor::from_str(desc).unwrap();
     println!("{:?}", extended_desc);
 
+    let policy = extended_desc.extract_policy().unwrap();
+    println!("policy: {}", serde_json::to_string(&policy).unwrap());
+
     let derived_desc = extended_desc.derive(42).unwrap();
     println!("{:?}", derived_desc);
 
-    if let Descriptor::Wsh(x) = &derived_desc {
-        println!("{}", serde_json::to_string(&x.extract_policy()).unwrap());
-    }
-
     let addr = derived_desc.address(Network::Testnet).unwrap();
     println!("{}", addr);
 
index 0430d962f776a5072708cadeb55cbe4053c217b2..b67b6e385f7360afe6d75ea3a437595cd52ceab1 100644 (file)
@@ -304,11 +304,10 @@ fn main() {
             let psbt: PartiallySignedTransaction = deserialize(&psbt).unwrap();
             let (psbt, finalized) = wallet.sign(psbt).unwrap();
 
+            println!("PSBT: {}", base64::encode(&serialize(&psbt)));
             println!("Finalized: {}", finalized);
             if finalized {
                 println!("Extracted: {}", serialize_hex(&psbt.extract_tx()));
-            } else {
-                println!("PSBT: {}", base64::encode(&serialize(&psbt)));
             }
         }
     };
index 0a58ab3dcd8e76a8997f612c156ea73176ebbca3..7a0d076eed5ae3119b02e9b5999cf5f802e0880c 100644 (file)
@@ -6,6 +6,8 @@ pub enum Error {
     MalformedInput,
     KeyParsingError(String),
 
+    Policy(crate::descriptor::policy::PolicyError),
+
     InputIndexDoesntExist,
     MissingPublicKey,
     MissingDetails,
@@ -32,3 +34,4 @@ impl_error!(bitcoin::util::base58::Error, Base58);
 impl_error!(bitcoin::util::key::Error, PK);
 impl_error!(miniscript::Error, Miniscript);
 impl_error!(bitcoin::hashes::hex::Error, Hex);
+impl_error!(crate::descriptor::policy::PolicyError, Policy);
index 71347f161c8b8848d2b45fe9c1e108f2d04f3179..4645eea0de72d450dbb11e3f62837582ce385c80 100644 (file)
@@ -75,7 +75,6 @@ impl DescriptorExtendedKey {
         final_path.into()
     }
 
-
     pub fn derive<C: secp256k1::Verification + secp256k1::Signing>(
         &self,
         ctx: &secp256k1::Secp256k1<C>,
index ac241b53a92c1118d5c92e26764bec0d91e32ffd..71d9ae1642d2848cfa5ec53e9e0729880088dea4 100644 (file)
@@ -10,7 +10,7 @@ use bitcoin::util::bip32::{DerivationPath, ExtendedPrivKey, Fingerprint};
 use bitcoin::util::psbt::PartiallySignedTransaction as PSBT;
 use bitcoin::{PrivateKey, PublicKey, Script};
 
-pub use miniscript::{descriptor::Descriptor, Miniscript};
+pub use miniscript::{Descriptor, Miniscript, MiniscriptKey, Terminal};
 
 use serde::{Deserialize, Serialize};
 
@@ -27,11 +27,14 @@ pub use self::extended_key::{DerivationIndex, DescriptorExtendedKey};
 pub use self::policy::Policy;
 
 trait MiniscriptExtractPolicy {
-    fn extract_policy(&self, lookup_map: &BTreeMap<String, Box<dyn Key>>) -> Option<Policy>;
+    fn extract_policy(
+        &self,
+        lookup_map: &BTreeMap<String, Box<dyn Key>>,
+    ) -> Result<Option<Policy>, Error>;
 }
 
 pub trait ExtractPolicy {
-    fn extract_policy(&self) -> Option<Policy>;
+    fn extract_policy(&self) -> Result<Option<Policy>, Error>;
 }
 
 #[derive(Debug, Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Default)]
@@ -228,6 +231,12 @@ impl std::clone::Clone for ExtendedDescriptor {
     }
 }
 
+impl std::convert::AsRef<StringDescriptor> for ExtendedDescriptor {
+    fn as_ref(&self) -> &StringDescriptor {
+        &self.internal
+    }
+}
+
 impl ExtendedDescriptor {
     fn parse_string(string: &str) -> Result<(String, Box<dyn Key>), Error> {
         if let Ok(pk) = PublicKey::from_str(string) {
@@ -271,13 +280,18 @@ impl ExtendedDescriptor {
         &self,
         miniscript: Miniscript<PublicKey>,
     ) -> Result<DerivedDescriptor, Error> {
-        // TODO: make sure they are "equivalent"
-        match self.internal {
-            Descriptor::Bare(_) => Ok(Descriptor::Bare(miniscript)),
-            Descriptor::Sh(_) => Ok(Descriptor::Sh(miniscript)),
-            Descriptor::Wsh(_) => Ok(Descriptor::Wsh(miniscript)),
-            Descriptor::ShWsh(_) => Ok(Descriptor::ShWsh(miniscript)),
-            _ => Err(Error::CantDeriveWithMiniscript),
+        let derived_desc = match self.internal {
+            Descriptor::Bare(_) => Descriptor::Bare(miniscript),
+            Descriptor::Sh(_) => Descriptor::Sh(miniscript),
+            Descriptor::Wsh(_) => Descriptor::Wsh(miniscript),
+            Descriptor::ShWsh(_) => Descriptor::ShWsh(miniscript),
+            _ => return Err(Error::CantDeriveWithMiniscript),
+        };
+
+        if !self.same_structure(&derived_desc) {
+            Err(Error::CantDeriveWithMiniscript)
+        } else {
+            Ok(derived_desc)
         }
     }
 
@@ -388,10 +402,29 @@ impl ExtendedDescriptor {
     pub fn is_fixed(&self) -> bool {
         self.keys.iter().all(|(_, key)| key.is_fixed())
     }
+
+    pub fn same_structure<K: MiniscriptKey>(&self, other: &Descriptor<K>) -> bool {
+        // Translate all the public keys to () and then check if the two descriptors are equal.
+        // TODO: translate hashes to their default value before checking for ==
+
+        let func_string = |_string: &String| -> Result<_, Error> { Ok(DummyKey::default()) };
+
+        let func_generic_pk = |_data: &K| -> Result<_, Error> { Ok(DummyKey::default()) };
+        let func_generic_pkh =
+            |_data: &<K as MiniscriptKey>::Hash| -> Result<_, Error> { Ok(DummyKey::default()) };
+
+        let translated_a = self.internal.translate_pk(func_string, func_string);
+        let translated_b = other.translate_pk(func_generic_pk, func_generic_pkh);
+
+        match (translated_a, translated_b) {
+            (Ok(a), Ok(b)) => a == b,
+            _ => false,
+        }
+    }
 }
 
 impl ExtractPolicy for ExtendedDescriptor {
-    fn extract_policy(&self) -> Option<Policy> {
+    fn extract_policy(&self) -> Result<Option<Policy>, Error> {
         self.internal.extract_policy(&self.keys)
     }
 }
@@ -479,7 +512,10 @@ mod test {
                 .to_string(),
             "mqwpxxvfv3QbM8PU8uBx2jaNt9btQqvQNx"
         );
-        assert_eq!(desc.get_secret_keys().into_iter().collect::<Vec<_>>().len(), 1);
+        assert_eq!(
+            desc.get_secret_keys().into_iter().collect::<Vec<_>>().len(),
+            1
+        );
     }
 
     #[test]
@@ -503,7 +539,10 @@ mod test {
                 .to_string(),
             "mqwpxxvfv3QbM8PU8uBx2jaNt9btQqvQNx"
         );
-        assert_eq!(desc.get_secret_keys().into_iter().collect::<Vec<_>>().len(), 0);
+        assert_eq!(
+            desc.get_secret_keys().into_iter().collect::<Vec<_>>().len(),
+            0
+        );
     }
 
     #[test]
index b8b59c878ec18f3c709b9f1581fbfdd313e742eb..9955164ba1015d13c5a5c30198f78ed10ae12d4e 100644 (file)
@@ -1,22 +1,30 @@
-use std::collections::{BTreeMap, HashSet};
+use std::cmp::max;
+use std::collections::{BTreeMap, HashSet, VecDeque};
 
-use serde::Serialize;
+use serde::ser::SerializeMap;
+use serde::{Serialize, Serializer};
 
 use bitcoin::hashes::*;
 use bitcoin::secp256k1::Secp256k1;
 use bitcoin::util::bip32::Fingerprint;
-use bitcoin::util::psbt;
 use bitcoin::PublicKey;
 
 use miniscript::{Descriptor, Miniscript, Terminal};
 
-use descriptor::{Key, MiniscriptExtractPolicy};
+#[allow(unused_imports)]
+use log::{debug, error, info, trace};
 
-#[derive(Debug, Serialize)]
+use super::error::Error;
+use crate::descriptor::{Key, MiniscriptExtractPolicy};
+use crate::psbt::PSBTSatisfier;
+
+#[derive(Debug, Clone, Default, Serialize)]
 pub struct PKOrF {
     #[serde(skip_serializing_if = "Option::is_none")]
     pubkey: Option<PublicKey>,
     #[serde(skip_serializing_if = "Option::is_none")]
+    pubkey_hash: Option<hash160::Hash>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     fingerprint: Option<Fingerprint>,
 }
 
@@ -28,28 +36,23 @@ impl PKOrF {
         if let Some(fing) = k.fingerprint(&secp) {
             PKOrF {
                 fingerprint: Some(fing),
-                pubkey: None,
+                ..Default::default()
             }
         } else {
             PKOrF {
-                fingerprint: None,
                 pubkey: Some(pubkey),
+                ..Default::default()
             }
         }
     }
 }
 
-#[derive(Debug, Serialize)]
+#[derive(Debug, Clone, Serialize)]
 #[serde(tag = "type", rename_all = "UPPERCASE")]
 pub enum SatisfiableItem {
     // Leaves
     Signature(PKOrF),
-    SignatureKey {
-        #[serde(skip_serializing_if = "Option::is_none")]
-        fingerprint: Option<Fingerprint>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pubkey_hash: Option<hash160::Hash>,
-    },
+    SignatureKey(PKOrF),
     SHA256Preimage {
         hash: sha256::Hash,
     },
@@ -90,95 +93,240 @@ impl SatisfiableItem {
             _ => true,
         }
     }
+}
 
-    fn satisfy(&self, _input: &psbt::Input) -> Satisfaction {
-        Satisfaction::None
+fn combinations(vec: &Vec<usize>, size: usize) -> Vec<Vec<usize>> {
+    assert!(vec.len() >= size);
+
+    let mut answer = Vec::new();
+
+    let mut queue = VecDeque::new();
+    for (index, val) in vec.iter().enumerate() {
+        let mut new_vec = Vec::with_capacity(size);
+        new_vec.push(*val);
+        queue.push_back((index, new_vec));
     }
+
+    while let Some((index, vals)) = queue.pop_front() {
+        if vals.len() >= size {
+            answer.push(vals);
+        } else {
+            for (new_index, val) in vec.iter().skip(index + 1).enumerate() {
+                let mut cloned = vals.clone();
+                cloned.push(*val);
+                queue.push_front((new_index, cloned));
+            }
+        }
+    }
+
+    answer
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
+fn mix<T: Clone>(vec: Vec<Vec<T>>) -> Vec<Vec<T>> {
+    if vec.is_empty() || vec.iter().any(Vec::is_empty) {
+        return vec![];
+    }
+
+    let mut answer = Vec::new();
+    let size = vec.len();
+
+    let mut queue = VecDeque::new();
+    for i in &vec[0] {
+        let mut new_vec = Vec::with_capacity(size);
+        new_vec.push(i.clone());
+        queue.push_back(new_vec);
+    }
+
+    while let Some(vals) = queue.pop_front() {
+        if vals.len() >= size {
+            answer.push(vals);
+        } else {
+            let level = vals.len();
+            for i in &vec[level] {
+                let mut cloned = vals.clone();
+                cloned.push(i.clone());
+                queue.push_front(cloned);
+            }
+        }
+    }
+
+    answer
+}
+
+pub type ConditionMap = BTreeMap<usize, HashSet<Condition>>;
+pub type FoldedConditionMap = BTreeMap<Vec<usize>, HashSet<Condition>>;
+
+fn serialize_folded_cond_map<S>(
+    input_map: &FoldedConditionMap,
+    serializer: S,
+) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let mut map = serializer.serialize_map(Some(input_map.len()))?;
+    for (k, v) in input_map {
+        let k_string = format!("{:?}", k);
+        map.serialize_entry(&k_string, v)?;
+    }
+    map.end()
+}
+
+#[derive(Debug, Clone, Serialize)]
 #[serde(tag = "type", rename_all = "UPPERCASE")]
 pub enum Satisfaction {
-    Complete {
-        #[serde(skip_serializing_if = "PathRequirements::is_null")]
-        condition: PathRequirements,
-    },
     Partial {
+        n: usize,
         m: usize,
+        items: Vec<usize>,
+        #[serde(skip_serializing_if = "BTreeMap::is_empty")]
+        conditions: ConditionMap,
+    },
+    PartialComplete {
         n: usize,
-        completed: HashSet<usize>,
+        m: usize,
+        items: Vec<usize>,
+        #[serde(
+            serialize_with = "serialize_folded_cond_map",
+            skip_serializing_if = "BTreeMap::is_empty"
+        )]
+        conditions: FoldedConditionMap,
+    },
+
+    Complete {
+        condition: Condition,
     },
     None,
 }
 
 impl Satisfaction {
-    fn from_items_threshold(items: HashSet<usize>, threshold: usize) -> Satisfaction {
-        Satisfaction::Partial {
-            m: items.len(),
-            n: threshold,
-            completed: items,
+    pub fn is_leaf(&self) -> bool {
+        match self {
+            Satisfaction::None | Satisfaction::Complete { .. } => true,
+            Satisfaction::PartialComplete { .. } | Satisfaction::Partial { .. } => false,
         }
     }
-}
 
-impl<'a> std::ops::Add<&'a Satisfaction> for Satisfaction {
-    type Output = Satisfaction;
+    // add `inner` as one of self's partial items. this only makes sense on partials
+    fn add(&mut self, inner: &Satisfaction, inner_index: usize) -> Result<(), PolicyError> {
+        match self {
+            Satisfaction::None | Satisfaction::Complete { .. } => Err(PolicyError::AddOnLeaf),
+            Satisfaction::PartialComplete { .. } => Err(PolicyError::AddOnPartialComplete),
+            Satisfaction::Partial {
+                n,
+                ref mut conditions,
+                ref mut items,
+                ..
+            } => {
+                if inner_index >= *n || items.contains(&inner_index) {
+                    return Err(PolicyError::IndexOutOfRange(inner_index));
+                }
+
+                match inner {
+                    // not relevant if not completed yet
+                    Satisfaction::None | Satisfaction::Partial { .. } => return Ok(()),
+                    Satisfaction::Complete { condition } => {
+                        items.push(inner_index);
+                        conditions.insert(inner_index, vec![*condition].into_iter().collect());
+                    }
+                    Satisfaction::PartialComplete {
+                        conditions: other_conditions,
+                        ..
+                    } => {
+                        items.push(inner_index);
+                        let conditions_set = other_conditions
+                            .values()
+                            .fold(HashSet::new(), |set, i| set.union(&i).cloned().collect());
+                        conditions.insert(inner_index, conditions_set);
+                    }
+                }
 
-    fn add(self, other: &'a Satisfaction) -> Satisfaction {
-        &self + other
+                Ok(())
+            }
+        }
     }
-}
 
-impl<'a, 'b> std::ops::Add<&'b Satisfaction> for &'a Satisfaction {
-    type Output = Satisfaction;
-
-    fn add(self, other: &'b Satisfaction) -> Satisfaction {
-        match (self, other) {
-            // complete-complete
-            (
-                Satisfaction::Complete { condition: mut a },
-                Satisfaction::Complete { condition: b },
-            ) => {
-                a.merge(&b).unwrap();
-                Satisfaction::Complete { condition: a }
+    fn finalize(&mut self) -> Result<(), PolicyError> {
+        // if partial try to bump it to a partialcomplete
+        if let Satisfaction::Partial {
+            n,
+            m,
+            items,
+            conditions,
+        } = self
+        {
+            if items.len() >= *m {
+                let mut map = BTreeMap::new();
+                let indexes = combinations(items, *m);
+                // `indexes` at this point is a Vec<Vec<usize>>, with the "n choose k" of items (m of n)
+                indexes
+                    .into_iter()
+                    // .inspect(|x| println!("--- orig --- {:?}", x))
+                    // we map each of the combinations of elements into a tuple of ([choosen items], [conditions]). unfortunately, those items have potentially more than one
+                    // condition (think about ORs), so we also use `mix` to expand those, i.e. [[0], [1, 2]] becomes [[0, 1], [0, 2]]. This is necessary to make sure that we
+                    // consider every possibile options and check whether or not they are compatible.
+                    .map(|i_vec| {
+                        mix(i_vec
+                            .iter()
+                            .map(|i| {
+                                conditions
+                                    .get(i)
+                                    .and_then(|set| Some(set.clone().into_iter().collect()))
+                                    .unwrap_or(vec![])
+                            })
+                            .collect())
+                        .into_iter()
+                        .map(|x| (i_vec.clone(), x))
+                        .collect::<Vec<(Vec<usize>, Vec<Condition>)>>()
+                    })
+                    // .inspect(|x: &Vec<(Vec<usize>, Vec<Condition>)>| println!("fetch {:?}", x))
+                    // since the previous step can turn one item of the iterator into multiple ones, we call flatten to expand them out
+                    .flatten()
+                    // .inspect(|x| println!("flat {:?}", x))
+                    // try to fold all the conditions for this specific combination of indexes/options. if they are not compatibile, try_fold will be Err
+                    .map(|(key, val)| {
+                        (
+                            key,
+                            val.into_iter()
+                                .try_fold(Condition::default(), |acc, v| acc.merge(&v)),
+                        )
+                    })
+                    // .inspect(|x| println!("try_fold {:?}", x))
+                    // filter out all the incompatible combinations
+                    .filter(|(_, val)| val.is_ok())
+                    // .inspect(|x| println!("filter {:?}", x))
+                    // push them into the map
+                    .for_each(|(key, val)| {
+                        map.entry(key)
+                            .or_insert_with(HashSet::new)
+                            .insert(val.unwrap());
+                    });
+                // TODO: if the map is empty, the conditions are not compatible, return an error?
+                *self = Satisfaction::PartialComplete {
+                    n: *n,
+                    m: *m,
+                    items: items.clone(),
+                    conditions: map,
+                };
             }
-            // complete-<any>
-            (Satisfaction::Complete { condition }, _) => Satisfaction::Complete {
-                condition: *condition,
-            },
-            (_, Satisfaction::Complete { condition }) => Satisfaction::Complete {
-                condition: *condition,
-            },
-
-            // none-<any>
-            (Satisfaction::None, any) => any.clone(),
-            (any, Satisfaction::None) => any.clone(),
-
-            // partial-partial
-            (
-                Satisfaction::Partial {
-                    m: _,
-                    n: a_n,
-                    completed: a_items,
-                },
-                Satisfaction::Partial {
-                    m: _,
-                    n: _,
-                    completed: b_items,
-                },
-            ) => {
-                let union: HashSet<_> = a_items.union(&b_items).cloned().collect();
-                Satisfaction::Partial {
-                    m: union.len(),
-                    n: *a_n,
-                    completed: union,
-                }
+        }
+
+        Ok(())
+    }
+}
+
+impl From<bool> for Satisfaction {
+    fn from(other: bool) -> Self {
+        if other {
+            Satisfaction::Complete {
+                condition: Default::default(),
             }
+        } else {
+            Satisfaction::None
         }
     }
 }
 
-#[derive(Debug, Serialize)]
+#[derive(Debug, Clone, Serialize)]
 pub struct Policy {
     #[serde(flatten)]
     item: SatisfiableItem,
@@ -186,39 +334,39 @@ pub struct Policy {
     contribution: Satisfaction,
 }
 
-#[derive(Debug, Default, Eq, PartialEq, Clone, Copy, Serialize)]
-pub struct PathRequirements {
+#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Serialize)]
+pub struct Condition {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub csv: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub timelock: Option<u32>,
 }
 
-impl PathRequirements {
-    pub fn merge(&mut self, other: &Self) -> Result<(), PolicyError> {
-        if other.is_null() {
-            return Ok(());
+impl Condition {
+    fn merge_timelock(a: u32, b: u32) -> Result<u32, PolicyError> {
+        const BLOCKS_TIMELOCK_THRESHOLD: u32 = 500000000;
+
+        if (a < BLOCKS_TIMELOCK_THRESHOLD) != (b < BLOCKS_TIMELOCK_THRESHOLD) {
+            Err(PolicyError::MixedTimelockUnits)
+        } else {
+            Ok(max(a, b))
         }
+    }
 
+    fn merge(mut self, other: &Condition) -> Result<Self, PolicyError> {
         match (self.csv, other.csv) {
-            (Some(old), Some(new)) if old != new => Err(PolicyError::DifferentCSV(old, new)),
-            _ => {
-                self.csv = self.csv.or(other.csv);
-                Ok(())
-            }
-        }?;
+            (Some(a), Some(b)) => self.csv = Some(Self::merge_timelock(a, b)?),
+            (None, any) => self.csv = any,
+            _ => {}
+        }
 
         match (self.timelock, other.timelock) {
-            // TODO: we could actually set the timelock to the highest of the two, but we would
-            // have to first check that they are both in the same "unit" (blocks vs time)
-            (Some(old), Some(new)) if old != new => Err(PolicyError::DifferentTimelock(old, new)),
-            _ => {
-                self.timelock = self.timelock.or(other.timelock);
-                Ok(())
-            }
-        }?;
+            (Some(a), Some(b)) => self.timelock = Some(Self::merge_timelock(a, b)?),
+            (None, any) => self.timelock = any,
+            _ => {}
+        }
 
-        Ok(())
+        Ok(self)
     }
 
     pub fn is_null(&self) -> bool {
@@ -230,9 +378,11 @@ impl PathRequirements {
 pub enum PolicyError {
     NotEnoughItemsSelected(usize),
     TooManyItemsSelected(usize),
-    IndexOutOfRange(usize, usize),
-    DifferentCSV(u32, u32),
-    DifferentTimelock(u32, u32),
+    IndexOutOfRange(usize),
+    AddOnLeaf,
+    AddOnPartialComplete,
+    MixedTimelockUnits,
+    IncompatibleConditions,
 }
 
 impl Policy {
@@ -244,71 +394,95 @@ impl Policy {
         }
     }
 
-    pub fn make_and(a: Option<Policy>, b: Option<Policy>) -> Option<Policy> {
+    pub fn make_and(a: Option<Policy>, b: Option<Policy>) -> Result<Option<Policy>, PolicyError> {
         match (a, b) {
-            (None, None) => None,
-            (Some(x), None) | (None, Some(x)) => Some(x),
+            (None, None) => Ok(None),
+            (Some(x), None) | (None, Some(x)) => Ok(Some(x)),
             (Some(a), Some(b)) => Self::make_thresh(vec![a, b], 2),
         }
     }
 
-    pub fn make_or(a: Option<Policy>, b: Option<Policy>) -> Option<Policy> {
+    pub fn make_or(a: Option<Policy>, b: Option<Policy>) -> Result<Option<Policy>, PolicyError> {
         match (a, b) {
-            (None, None) => None,
-            (Some(x), None) | (None, Some(x)) => Some(x),
+            (None, None) => Ok(None),
+            (Some(x), None) | (None, Some(x)) => Ok(Some(x)),
             (Some(a), Some(b)) => Self::make_thresh(vec![a, b], 1),
         }
     }
 
-    pub fn make_thresh(items: Vec<Policy>, threshold: usize) -> Option<Policy> {
+    pub fn make_thresh(
+        items: Vec<Policy>,
+        threshold: usize,
+    ) -> Result<Option<Policy>, PolicyError> {
         if threshold == 0 {
-            return None;
+            return Ok(None);
         }
 
-        let contribution = items.iter().fold(
-            Satisfaction::Partial {
-                m: 0,
-                n: threshold,
-                completed: HashSet::new(),
-            },
-            |acc, x| acc + &x.contribution,
-        );
+        let mut contribution = Satisfaction::Partial {
+            n: items.len(),
+            m: threshold,
+            items: vec![],
+            conditions: Default::default(),
+        };
+        for (index, item) in items.iter().enumerate() {
+            contribution.add(&item.contribution, index)?;
+        }
+        contribution.finalize()?;
+
         let mut policy: Policy = SatisfiableItem::Thresh { items, threshold }.into();
         policy.contribution = contribution;
 
-        Some(policy)
+        Ok(Some(policy))
     }
 
-    fn make_multisig(keys: Vec<Option<&Box<dyn Key>>>, threshold: usize) -> Option<Policy> {
+    fn make_multisig(
+        keys: Vec<Option<&Box<dyn Key>>>,
+        threshold: usize,
+    ) -> Result<Option<Policy>, PolicyError> {
+        if threshold == 0 {
+            return Ok(None);
+        }
+
         let parsed_keys = keys.iter().map(|k| PKOrF::from_key(k.unwrap())).collect();
+
+        let mut contribution = Satisfaction::Partial {
+            n: keys.len(),
+            m: threshold,
+            items: vec![],
+            conditions: Default::default(),
+        };
+        for (index, key) in keys.iter().enumerate() {
+            let val = if key.is_some() && key.unwrap().has_secret() {
+                Satisfaction::Complete {
+                    condition: Default::default(),
+                }
+            } else {
+                Satisfaction::None
+            };
+            contribution.add(&val, index)?;
+        }
+        contribution.finalize()?;
+
         let mut policy: Policy = SatisfiableItem::Multisig {
             keys: parsed_keys,
             threshold,
         }
         .into();
-        let our_keys = keys
-            .iter()
-            .enumerate()
-            .filter(|(_, x)| x.is_some() && x.unwrap().has_secret())
-            .map(|(k, _)| k)
-            .collect();
-        policy.contribution = Satisfaction::from_items_threshold(our_keys, threshold);
-
-        Some(policy)
+        policy.contribution = contribution;
+
+        Ok(Some(policy))
     }
 
-    pub fn satisfy(&mut self, input: &psbt::Input) {
-        self.satisfaction = self.item.satisfy(input);
+    pub fn satisfy(&mut self, _satisfier: &PSBTSatisfier, _desc_node: &Terminal<PublicKey>) {
+        //self.satisfaction = self.item.satisfy(satisfier, desc_node);
+        //self.contribution += &self.satisfaction;
     }
 
     pub fn requires_path(&self) -> bool {
         self.get_requirements(&vec![]).is_err()
     }
 
-    pub fn get_requirements(
-        &self,
-        path: &Vec<Vec<usize>>,
-    ) -> Result<PathRequirements, PolicyError> {
+    pub fn get_requirements(&self, path: &Vec<Vec<usize>>) -> Result<Condition, PolicyError> {
         self.recursive_get_requirements(path, 0)
     }
 
@@ -316,7 +490,7 @@ impl Policy {
         &self,
         path: &Vec<Vec<usize>>,
         index: usize,
-    ) -> Result<PathRequirements, PolicyError> {
+    ) -> Result<Condition, PolicyError> {
         // if items.len() == threshold, selected can be omitted and we take all of them by default
         let default = match &self.item {
             SatisfiableItem::Thresh { items, threshold } if items.len() == *threshold => {
@@ -339,8 +513,8 @@ impl Policy {
 
                 // if all the requirements are null we don't care about `selected` because there
                 // are no requirements
-                if mapped_req.iter().all(PathRequirements::is_null) {
-                    return Ok(PathRequirements::default());
+                if mapped_req.iter().all(Condition::is_null) {
+                    return Ok(Condition::default());
                 }
 
                 // if we have something, make sure we have enough items. note that the user can set
@@ -351,27 +525,27 @@ impl Policy {
                 }
 
                 // check the selected items, see if there are conflicting requirements
-                let mut requirements = PathRequirements::default();
+                let mut requirements = Condition::default();
                 for item_index in selected {
-                    requirements.merge(
+                    requirements = requirements.merge(
                         mapped_req
                             .get(*item_index)
-                            .ok_or(PolicyError::IndexOutOfRange(*item_index, index))?,
+                            .ok_or(PolicyError::IndexOutOfRange(*item_index))?,
                     )?;
                 }
 
                 Ok(requirements)
             }
             _ if !selected.is_empty() => Err(PolicyError::TooManyItemsSelected(index)),
-            SatisfiableItem::AbsoluteTimelock { value } => Ok(PathRequirements {
+            SatisfiableItem::AbsoluteTimelock { value } => Ok(Condition {
                 csv: None,
                 timelock: Some(*value),
             }),
-            SatisfiableItem::RelativeTimelock { value } => Ok(PathRequirements {
+            SatisfiableItem::RelativeTimelock { value } => Ok(Condition {
                 csv: Some(*value),
                 timelock: None,
             }),
-            _ => Ok(PathRequirements::default()),
+            _ => Ok(Condition::default()),
         }
     }
 }
@@ -403,15 +577,15 @@ fn signature_key_from_string(key: Option<&Box<dyn Key>>) -> Option<Policy> {
     key.map(|k| {
         let pubkey = k.as_public_key(&secp, None).unwrap();
         let mut policy: Policy = if let Some(fing) = k.fingerprint(&secp) {
-            SatisfiableItem::SignatureKey {
+            SatisfiableItem::SignatureKey(PKOrF {
                 fingerprint: Some(fing),
-                pubkey_hash: None,
-            }
+                ..Default::default()
+            })
         } else {
-            SatisfiableItem::SignatureKey {
-                fingerprint: None,
+            SatisfiableItem::SignatureKey(PKOrF {
                 pubkey_hash: Some(hash160::Hash::hash(&pubkey.to_bytes())),
-            }
+                ..Default::default()
+            })
         }
         .into();
         policy.contribution = if k.has_secret() {
@@ -427,8 +601,11 @@ fn signature_key_from_string(key: Option<&Box<dyn Key>>) -> Option<Policy> {
 }
 
 impl MiniscriptExtractPolicy for Miniscript<String> {
-    fn extract_policy(&self, lookup_map: &BTreeMap<String, Box<dyn Key>>) -> Option<Policy> {
-        match &self.node {
+    fn extract_policy(
+        &self,
+        lookup_map: &BTreeMap<String, Box<dyn Key>>,
+    ) -> Result<Option<Policy>, Error> {
+        Ok(match &self.node {
             // Leaves
             Terminal::True | Terminal::False => None,
             Terminal::Pk(pubkey) => signature_from_string(lookup_map.get(pubkey)),
@@ -436,9 +613,9 @@ impl MiniscriptExtractPolicy for Miniscript<String> {
             Terminal::After(value) => {
                 let mut policy: Policy = SatisfiableItem::AbsoluteTimelock { value: *value }.into();
                 policy.contribution = Satisfaction::Complete {
-                    condition: PathRequirements {
-                        csv: None,
+                    condition: Condition {
                         timelock: Some(*value),
+                        csv: None,
                     },
                 };
 
@@ -447,9 +624,9 @@ impl MiniscriptExtractPolicy for Miniscript<String> {
             Terminal::Older(value) => {
                 let mut policy: Policy = SatisfiableItem::RelativeTimelock { value: *value }.into();
                 policy.contribution = Satisfaction::Complete {
-                    condition: PathRequirements {
-                        csv: Some(*value),
+                    condition: Condition {
                         timelock: None,
+                        csv: Some(*value),
                     },
                 };
 
@@ -466,7 +643,7 @@ impl MiniscriptExtractPolicy for Miniscript<String> {
                 Some(SatisfiableItem::HASH160Preimage { hash: *hash }.into())
             }
             Terminal::ThreshM(k, pks) => {
-                Policy::make_multisig(pks.iter().map(|s| lookup_map.get(s)).collect(), *k)
+                Policy::make_multisig(pks.iter().map(|s| lookup_map.get(s)).collect(), *k)?
             }
             // Identities
             Terminal::Alt(inner)
@@ -475,52 +652,58 @@ impl MiniscriptExtractPolicy for Miniscript<String> {
             | Terminal::DupIf(inner)
             | Terminal::Verify(inner)
             | Terminal::NonZero(inner)
-            | Terminal::ZeroNotEqual(inner) => inner.extract_policy(lookup_map),
+            | Terminal::ZeroNotEqual(inner) => inner.extract_policy(lookup_map)?,
             // Complex policies
             Terminal::AndV(a, b) | Terminal::AndB(a, b) => {
-                Policy::make_and(a.extract_policy(lookup_map), b.extract_policy(lookup_map))
+                Policy::make_and(a.extract_policy(lookup_map)?, b.extract_policy(lookup_map)?)?
             }
             Terminal::AndOr(x, y, z) => Policy::make_or(
-                Policy::make_and(x.extract_policy(lookup_map), y.extract_policy(lookup_map)),
-                z.extract_policy(lookup_map),
-            ),
+                Policy::make_and(x.extract_policy(lookup_map)?, y.extract_policy(lookup_map)?)?,
+                z.extract_policy(lookup_map)?,
+            )?,
             Terminal::OrB(a, b)
             | Terminal::OrD(a, b)
             | Terminal::OrC(a, b)
             | Terminal::OrI(a, b) => {
-                Policy::make_or(a.extract_policy(lookup_map), b.extract_policy(lookup_map))
+                Policy::make_or(a.extract_policy(lookup_map)?, b.extract_policy(lookup_map)?)?
             }
             Terminal::Thresh(k, nodes) => {
                 let mut threshold = *k;
                 let mapped: Vec<_> = nodes
                     .iter()
-                    .filter_map(|n| n.extract_policy(lookup_map))
+                    .map(|n| n.extract_policy(lookup_map))
+                    .collect::<Result<Vec<_>, _>>()?
+                    .into_iter()
+                    .filter_map(|x| x)
                     .collect();
 
                 if mapped.len() < nodes.len() {
                     threshold = match threshold.checked_sub(nodes.len() - mapped.len()) {
-                        None => return None,
+                        None => return Ok(None),
                         Some(x) => x,
                     };
                 }
 
-                Policy::make_thresh(mapped, threshold)
+                Policy::make_thresh(mapped, threshold)?
             }
-        }
+        })
     }
 }
 
 impl MiniscriptExtractPolicy for Descriptor<String> {
-    fn extract_policy(&self, lookup_map: &BTreeMap<String, Box<dyn Key>>) -> Option<Policy> {
+    fn extract_policy(
+        &self,
+        lookup_map: &BTreeMap<String, Box<dyn Key>>,
+    ) -> Result<Option<Policy>, Error> {
         match self {
             Descriptor::Pk(pubkey)
             | Descriptor::Pkh(pubkey)
             | Descriptor::Wpkh(pubkey)
-            | Descriptor::ShWpkh(pubkey) => signature_from_string(lookup_map.get(pubkey)),
+            | Descriptor::ShWpkh(pubkey) => Ok(signature_from_string(lookup_map.get(pubkey))),
             Descriptor::Bare(inner)
             | Descriptor::Sh(inner)
             | Descriptor::Wsh(inner)
-            | Descriptor::ShWsh(inner) => inner.extract_policy(lookup_map),
+            | Descriptor::ShWsh(inner) => Ok(inner.extract_policy(lookup_map)?),
         }
     }
 }
index cee77b8ddae5555f8f0bba38a5f20d4f963c9b6d..fa17199e432c13baa66bf411d2732c341cdd2488 100644 (file)
@@ -14,6 +14,7 @@ pub enum Error {
     DifferentTransactions,
 
     ChecksumMismatch,
+    DifferentDescriptorStructure,
 
     SpendingPolicyRequired,
     InvalidPolicyPathError(crate::descriptor::policy::PolicyError),
index 9c63b67bf2ad3be6d7e850be05deee7417acac3d..daba328918742e8870608da53b73e8d03d6208af 100644 (file)
@@ -9,7 +9,6 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH};
 use bitcoin::blockdata::opcodes;
 use bitcoin::blockdata::script::Builder;
 use bitcoin::consensus::encode::serialize;
-use bitcoin::secp256k1::{All, Secp256k1};
 use bitcoin::util::bip32::{ChildNumber, DerivationPath};
 use bitcoin::util::psbt::PartiallySignedTransaction as PSBT;
 use bitcoin::{
@@ -45,8 +44,7 @@ pub struct Wallet<S: Read + Write, D: BatchDatabase> {
     network: Network,
 
     client: Option<RefCell<Client<S>>>,
-    database: RefCell<D>, // TODO: save descriptor checksum and check when loading
-    _secp: Secp256k1<All>,
+    database: RefCell<D>,
 }
 
 // offline actions, always available
@@ -72,13 +70,17 @@ where
                     ScriptType::Internal,
                     get_checksum(desc)?.as_bytes(),
                 )?;
-                Some(ExtendedDescriptor::from_str(desc)?)
+
+                let parsed = ExtendedDescriptor::from_str(desc)?;
+                if !parsed.same_structure(descriptor.as_ref()) {
+                    return Err(Error::DifferentDescriptorStructure);
+                }
+
+                Some(parsed)
             }
             None => None,
         };
 
-        // TODO: make sure that both descriptor have the same structure
-
         Ok(Wallet {
             descriptor,
             change_descriptor,
@@ -86,7 +88,6 @@ where
 
             client: None,
             database: RefCell::new(database),
-            _secp: Secp256k1::gen_new(),
         })
     }
 
@@ -132,13 +133,11 @@ where
         utxos: Option<Vec<OutPoint>>,
         unspendable: Option<Vec<OutPoint>>,
     ) -> Result<(PSBT, TransactionDetails), Error> {
-        let policy = self.descriptor.extract_policy().unwrap();
+        let policy = self.descriptor.extract_policy()?.unwrap();
         if policy.requires_path() && policy_path.is_none() {
             return Err(Error::SpendingPolicyRequired);
         }
-        let requirements = policy_path.map_or(Ok(Default::default()), |path| {
-            policy.get_requirements(&path)
-        })?;
+        let requirements = policy.get_requirements(&policy_path.unwrap_or(vec![]))?;
         debug!("requirements: {:?}", requirements);
 
         let mut tx = Transaction {
@@ -197,9 +196,14 @@ where
             input_witness_weight,
             fee_val,
         )?;
-        inputs
-            .iter_mut()
-            .for_each(|i| i.sequence = requirements.csv.unwrap_or(0xFFFFFFFF));
+        let n_sequence = if let Some(csv) = requirements.csv {
+            csv
+        } else if requirements.timelock.is_some() {
+            0xFFFFFFFE
+        } else {
+            0xFFFFFFFF
+        };
+        inputs.iter_mut().for_each(|i| i.sequence = n_sequence);
         tx.input.append(&mut inputs);
 
         // prepare the change output
@@ -300,9 +304,8 @@ where
         Ok((psbt, transaction_details))
     }
 
-    // TODO: define an enum for signing errors
-    pub fn sign(&self, mut psbt: PSBT) -> Result<(PSBT, bool), Error> {
-        let tx = &psbt.global.unsigned_tx;
+    // TODO: move down to the "internals"
+    fn add_hd_keypaths(&self, psbt: &mut PSBT) -> Result<(), Error> {
         let mut input_utxos = Vec::with_capacity(psbt.inputs.len());
         for n in 0..psbt.inputs.len() {
             input_utxos.push(psbt.get_utxo_for(n).clone());
@@ -339,6 +342,16 @@ where
             }
         }
 
+        Ok(())
+    }
+
+    // TODO: define an enum for signing errors
+    pub fn sign(&self, mut psbt: PSBT) -> Result<(PSBT, bool), Error> {
+        // this helps us doing our job later
+        self.add_hd_keypaths(&mut psbt)?;
+
+        let tx = &psbt.global.unsigned_tx;
+
         let mut signer = PSBTSigner::from_descriptor(&psbt.global.unsigned_tx, &self.descriptor)?;
         if let Some(desc) = &self.change_descriptor {
             let change_signer = PSBTSigner::from_descriptor(&psbt.global.unsigned_tx, desc)?;
@@ -480,9 +493,9 @@ where
 
     pub fn policies(&self, script_type: ScriptType) -> Result<Option<Policy>, Error> {
         match (script_type, self.change_descriptor.as_ref()) {
-            (ScriptType::External, _) => Ok(self.descriptor.extract_policy()),
+            (ScriptType::External, _) => Ok(self.descriptor.extract_policy()?),
             (ScriptType::Internal, None) => Ok(None),
-            (ScriptType::Internal, Some(desc)) => Ok(desc.extract_policy()),
+            (ScriptType::Internal, Some(desc)) => Ok(desc.extract_policy()?),
         }
     }
 
@@ -688,13 +701,17 @@ where
                     ScriptType::Internal,
                     get_checksum(desc)?.as_bytes(),
                 )?;
-                Some(ExtendedDescriptor::from_str(desc)?)
+
+                let parsed = ExtendedDescriptor::from_str(desc)?;
+                if !parsed.same_structure(descriptor.as_ref()) {
+                    return Err(Error::DifferentDescriptorStructure);
+                }
+
+                Some(parsed)
             }
             None => None,
         };
 
-        // TODO: make sure that both descriptor have the same structure
-
         Ok(Wallet {
             descriptor,
             change_descriptor,
@@ -702,7 +719,6 @@ where
 
             client: Some(RefCell::new(client)),
             database: RefCell::new(database),
-            _secp: Secp256k1::gen_new(),
         })
     }
 
@@ -944,7 +960,7 @@ where
                 .as_ref()
                 .unwrap()
                 .borrow_mut()
-                .batch_script_get_history(chunk.iter().collect::<Vec<_>>())?; // TODO: fix electrum client
+                .batch_script_get_history(chunk.iter())?;
 
             for (script, history) in chunk.into_iter().zip(call_result.into_iter()) {
                 trace!("received history for {:?}, size {}", script, history.len());