]> Untitled Git - bdk/commitdiff
[wallet] Add explicit ordering for the signers
authorAlekos Filini <alekos.filini@gmail.com>
Mon, 17 Aug 2020 10:10:51 +0000 (12:10 +0200)
committerAlekos Filini <alekos.filini@gmail.com>
Sun, 30 Aug 2020 18:38:20 +0000 (20:38 +0200)
src/wallet/mod.rs
src/wallet/signer.rs

index 46a93b86d9c6f09f79467b793e413b9495a432c1..57d822add158f16d8eaf080c80b3951228e6d2b9 100644 (file)
@@ -24,7 +24,7 @@ pub mod tx_builder;
 pub mod utils;
 
 use address_validator::AddressValidator;
-use signer::{Signer, SignerId, SignersContainer};
+use signer::{Signer, SignerId, SignerOrdering, SignersContainer};
 use tx_builder::TxBuilder;
 use utils::{After, FeeRate, IsDust, Older};
 
@@ -142,6 +142,7 @@ where
         &mut self,
         script_type: ScriptType,
         id: SignerId<DescriptorPublicKey>,
+        ordering: SignerOrdering,
         signer: Arc<Box<dyn Signer>>,
     ) {
         let signers = match script_type {
@@ -149,7 +150,7 @@ where
             ScriptType::Internal => Arc::make_mut(&mut self.change_signers),
         };
 
-        signers.add_external(id, signer);
+        signers.add_external(id, ordering, signer);
     }
 
     pub fn add_address_validator(&mut self, validator: Arc<Box<dyn AddressValidator>>) {
@@ -575,15 +576,18 @@ where
         Ok((psbt, details))
     }
 
-    // TODO: define an enum for signing errors
     pub fn sign(&self, mut psbt: PSBT, assume_height: Option<u32>) -> Result<(PSBT, bool), Error> {
         // this helps us doing our job later
         self.add_input_hd_keypaths(&mut psbt)?;
 
-        for index in 0..psbt.inputs.len() {
-            self.signers.sign(&mut psbt, index)?;
-            if self.change_descriptor.is_some() {
-                self.change_signers.sign(&mut psbt, index)?;
+        for signer in self
+            .signers
+            .signers()
+            .iter()
+            .chain(self.change_signers.signers().iter())
+        {
+            for index in 0..psbt.inputs.len() {
+                signer.sign(&mut psbt, index)?;
             }
         }
 
index 687bc8c67077ce4a8a17e67b6b5e2a13542819aa..d94589c1d39e78f950716b76874d048cbaad06f5 100644 (file)
@@ -1,6 +1,8 @@
 use std::any::Any;
-use std::collections::HashMap;
+use std::cmp::Ordering;
+use std::collections::BTreeMap;
 use std::fmt;
+use std::ops::Bound::Included;
 use std::sync::Arc;
 
 use bitcoin::blockdata::opcodes;
@@ -150,9 +152,35 @@ impl Signer for PrivateKey {
     }
 }
 
+#[derive(Debug, Clone, PartialOrd, PartialEq, Ord, Eq)]
+pub struct SignerOrdering(pub usize);
+
+impl std::default::Default for SignerOrdering {
+    fn default() -> Self {
+        SignerOrdering(100)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SignersContainerKey<Pk: MiniscriptKey> {
+    id: SignerId<Pk>,
+    ordering: SignerOrdering,
+}
+
+impl<Pk: MiniscriptKey> From<(SignerId<Pk>, SignerOrdering)> for SignersContainerKey<Pk> {
+    fn from(tuple: (SignerId<Pk>, SignerOrdering)) -> Self {
+        SignersContainerKey {
+            id: tuple.0,
+            ordering: tuple.1,
+        }
+    }
+}
+
 /// Container for multiple signers
 #[derive(Debug, Default, Clone)]
-pub struct SignersContainer<Pk: MiniscriptKey>(HashMap<SignerId<Pk>, Arc<Box<dyn Signer>>>);
+pub struct SignersContainer<Pk: MiniscriptKey>(
+    BTreeMap<SignersContainerKey<Pk>, Arc<Box<dyn Signer>>>,
+);
 
 impl SignersContainer<DescriptorPublicKey> {
     pub fn as_key_map(&self) -> KeyMap {
@@ -190,10 +218,12 @@ impl From<KeyMap> for SignersContainer<DescriptorPublicKey> {
                             .public_key(&Secp256k1::signing_only())
                             .to_pubkeyhash(),
                     ),
+                    SignerOrdering::default(),
                     Arc::new(Box::new(private_key)),
                 ),
                 DescriptorSecretKey::XPrv(xprv) => container.add_external(
                     SignerId::from(xprv.root_fingerprint()),
+                    SignerOrdering::default(),
                     Arc::new(Box::new(xprv)),
                 ),
             };
@@ -206,7 +236,7 @@ impl From<KeyMap> for SignersContainer<DescriptorPublicKey> {
 impl<Pk: MiniscriptKey> SignersContainer<Pk> {
     /// Default constructor
     pub fn new() -> Self {
-        SignersContainer(HashMap::new())
+        SignersContainer(Default::default())
     }
 
     /// Adds an external signer to the container for the specified id. Optionally returns the
@@ -214,24 +244,43 @@ impl<Pk: MiniscriptKey> SignersContainer<Pk> {
     pub fn add_external(
         &mut self,
         id: SignerId<Pk>,
+        ordering: SignerOrdering,
         signer: Arc<Box<dyn Signer>>,
     ) -> Option<Arc<Box<dyn Signer>>> {
-        self.0.insert(id, signer)
+        self.0.insert((id, ordering).into(), signer)
     }
 
     /// Removes a signer from the container and returns it
-    pub fn remove(&mut self, id: SignerId<Pk>) -> Option<Arc<Box<dyn Signer>>> {
-        self.0.remove(&id)
+    pub fn remove(
+        &mut self,
+        id: SignerId<Pk>,
+        ordering: SignerOrdering,
+    ) -> Option<Arc<Box<dyn Signer>>> {
+        self.0.remove(&(id, ordering).into())
     }
 
     /// Returns the list of identifiers of all the signers in the container
     pub fn ids(&self) -> Vec<&SignerId<Pk>> {
-        self.0.keys().collect()
+        self.0
+            .keys()
+            .map(|SignersContainerKey { id, .. }| id)
+            .collect()
+    }
+
+    /// Returns the list of signers in the container, sorted by lowest to highest `ordering`
+    pub fn signers(&self) -> Vec<&Arc<Box<dyn Signer>>> {
+        self.0.values().collect()
     }
 
-    /// Finds the signer with a given id in the container
+    /// Finds the signer with lowest ordering for a given id in the container.
     pub fn find(&self, id: SignerId<Pk>) -> Option<&Arc<Box<dyn Signer>>> {
-        self.0.get(&id)
+        self.0
+            .range((
+                Included(&(id.clone(), SignerOrdering(0)).into()),
+                Included(&(id, SignerOrdering(usize::MAX)).into()),
+            ))
+            .map(|(_, v)| v)
+            .nth(0)
     }
 }
 
@@ -327,3 +376,23 @@ impl ComputeSighash for Segwitv0 {
         ))
     }
 }
+
+impl<Pk: MiniscriptKey> PartialOrd for SignersContainerKey<Pk> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<Pk: MiniscriptKey> Ord for SignersContainerKey<Pk> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.ordering.cmp(&other.ordering)
+    }
+}
+
+impl<Pk: MiniscriptKey> PartialEq for SignersContainerKey<Pk> {
+    fn eq(&self, other: &Self) -> bool {
+        self.ordering == other.ordering
+    }
+}
+
+impl<Pk: MiniscriptKey> Eq for SignersContainerKey<Pk> {}