]> Untitled Git - bdk/commitdiff
[wallet] Add `force_non_witness_utxo()` to TxBuilder
authorAlekos Filini <alekos.filini@gmail.com>
Sat, 8 Aug 2020 10:06:40 +0000 (12:06 +0200)
committerAlekos Filini <alekos.filini@gmail.com>
Mon, 10 Aug 2020 15:18:15 +0000 (17:18 +0200)
src/blockchain/electrum.rs
src/blockchain/esplora.rs
src/blockchain/mod.rs
src/cli.rs
src/wallet/mod.rs
src/wallet/tx_builder.rs

index 5ac861519a0743508bef0f5148f79251b36e0861..8913feb8fe92673a6175943a8f8796a3c5590964 100644 (file)
@@ -68,7 +68,7 @@ impl OnlineBlockchain for ElectrumBlockchain {
             .map(|_| ())?)
     }
 
-    fn get_height(&self) -> Result<usize, Error> {
+    fn get_height(&self) -> Result<u32, Error> {
         // TODO: unsubscribe when added to the client, or is there a better call to use here?
 
         Ok(self
@@ -76,7 +76,7 @@ impl OnlineBlockchain for ElectrumBlockchain {
             .as_ref()
             .ok_or(Error::OfflineClient)?
             .block_headers_subscribe()
-            .map(|data| data.height)?)
+            .map(|data| data.height as u32)?)
     }
 
     fn estimate_fee(&self, target: usize) -> Result<FeeRate, Error> {
index 53dcb5b3862f249fc7f564b91da266622b44f6c4..10eb5cb4fef200ab37628a0eb4257ddf3cd814f0 100644 (file)
@@ -93,7 +93,7 @@ impl OnlineBlockchain for EsploraBlockchain {
             ._broadcast(tx))?)
     }
 
-    fn get_height(&self) -> Result<usize, Error> {
+    fn get_height(&self) -> Result<u32, Error> {
         Ok(await_or_block!(self
             .0
             .as_ref()
@@ -153,7 +153,7 @@ impl UrlClient {
         Ok(())
     }
 
-    async fn _get_height(&self) -> Result<usize, EsploraError> {
+    async fn _get_height(&self) -> Result<u32, EsploraError> {
         let req = self
             .client
             .get(&format!("{}/api/blocks/tip/height", self.url))
index 08e466948edcaf645f840c334626337126ac8010..da267335da9e43127b6e36987ee974a38555b249 100644 (file)
@@ -64,7 +64,7 @@ pub trait OnlineBlockchain: Blockchain {
     fn get_tx(&self, txid: &Txid) -> Result<Option<Transaction>, Error>;
     fn broadcast(&self, tx: &Transaction) -> Result<(), Error>;
 
-    fn get_height(&self) -> Result<usize, Error>;
+    fn get_height(&self) -> Result<u32, Error>;
     fn estimate_fee(&self, target: usize) -> Result<FeeRate, Error>;
 }
 
index 48f5e42b565ea3c3eb45fc528be2f3406b9098fd..164ae4a2b6ffe3fbb1fb4706ba8e8fd344e9a0a7 100644 (file)
@@ -326,8 +326,11 @@ where
             .map(|s| parse_addressee(s))
             .collect::<Result<Vec<_>, _>>()
             .map_err(|s| Error::Generic(s))?;
-        let mut tx_builder =
-            TxBuilder::from_addressees(addressees).send_all(sub_matches.is_present("send_all"));
+        let mut tx_builder = TxBuilder::from_addressees(addressees);
+
+        if sub_matches.is_present("send_all") {
+            tx_builder = tx_builder.send_all();
+        }
 
         if let Some(fee_rate) = sub_matches.value_of("fee_rate") {
             let fee_rate = f32::from_str(fee_rate).map_err(|s| Error::Generic(s.to_string()))?;
index b6cd853217e882c4bd4724803a162504a76c0250..cd2bfbc838842d917f715f8f3f5aefa7548e7849 100644 (file)
@@ -1,4 +1,5 @@
 use std::cell::RefCell;
+use std::collections::HashMap;
 use std::collections::{BTreeMap, HashSet};
 use std::ops::DerefMut;
 use std::str::FromStr;
@@ -236,6 +237,12 @@ where
             fee_amount,
         )?;
         let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip();
+        // map that allows us to lookup the prev_script_pubkey for a given previous_output
+        let prev_script_pubkeys = txin
+            .iter()
+            .zip(prev_script_pubkeys.into_iter())
+            .map(|(txin, script)| (txin.previous_output, script))
+            .collect::<HashMap<_, _>>();
 
         txin.iter_mut().for_each(|i| i.sequence = n_sequence);
         tx.input = txin;
@@ -285,12 +292,13 @@ where
         let mut psbt = PSBT::from_unsigned_tx(tx)?;
 
         // add metadata for the inputs
-        for ((psbt_input, prev_script), input) in psbt
+        for (psbt_input, input) in psbt
             .inputs
             .iter_mut()
-            .zip(prev_script_pubkeys.into_iter())
             .zip(psbt.global.unsigned_tx.input.iter())
         {
+            let prev_script = prev_script_pubkeys.get(&input.previous_output).unwrap();
+
             // Add sighash, default is obviously "ALL"
             psbt_input.sighash_type = builder.sighash.or(Some(SigHashType::All));
 
@@ -317,7 +325,8 @@ where
                 if derived_descriptor.is_witness() {
                     psbt_input.witness_utxo =
                         Some(prev_tx.output[prev_output.vout as usize].clone());
-                } else {
+                }
+                if !derived_descriptor.is_witness() || builder.force_non_witness_utxo {
                     psbt_input.non_witness_utxo = Some(prev_tx);
                 }
             }
@@ -535,7 +544,6 @@ where
                 n, input.previous_output, create_height, current_height
             );
 
-            // TODO: use height once we sync headers
             let satisfier =
                 PSBTSatisfier::new(&psbt.inputs[n], false, create_height, current_height);
 
@@ -778,17 +786,16 @@ where
         ))
     }
 
+    pub fn client(&self) -> &B {
+        &self.client
+    }
+
     #[maybe_async]
     pub fn broadcast(&self, tx: Transaction) -> Result<Txid, Error> {
         maybe_await!(self.client.broadcast(&tx))?;
 
         Ok(tx.txid())
     }
-
-    #[maybe_async]
-    pub fn estimate_fee(&self, target: usize) -> Result<FeeRate, Error> {
-        Ok(maybe_await!(self.client.estimate_fee(target))?)
-    }
 }
 
 #[cfg(test)]
index 21fc932e68807310a8fec2a106e380f0c32bc46a..ea34fbcffc254ffe69c56baf819532c78f8e51d1 100644 (file)
@@ -7,7 +7,6 @@ use super::coin_selection::{CoinSelectionAlgorithm, DefaultCoinSelectionAlgorith
 use super::utils::FeeRate;
 use crate::types::UTXO;
 
-// TODO: add a flag to ignore change outputs (make them unspendable)
 #[derive(Debug, Default)]
 pub struct TxBuilder<Cs: CoinSelectionAlgorithm> {
     pub(crate) addressees: Vec<(Address, u64)>,
@@ -22,6 +21,7 @@ pub struct TxBuilder<Cs: CoinSelectionAlgorithm> {
     pub(crate) rbf: Option<u32>,
     pub(crate) version: Version,
     pub(crate) change_policy: ChangeSpendPolicy,
+    pub(crate) force_non_witness_utxo: bool,
     pub(crate) coin_selection: Cs,
 }
 
@@ -46,8 +46,8 @@ impl<Cs: CoinSelectionAlgorithm> TxBuilder<Cs> {
         self
     }
 
-    pub fn send_all(mut self, send_all: bool) -> Self {
-        self.send_all = send_all;
+    pub fn send_all(mut self) -> Self {
+        self.send_all = true;
         self
     }
 
@@ -122,6 +122,16 @@ impl<Cs: CoinSelectionAlgorithm> TxBuilder<Cs> {
         self
     }
 
+    pub fn change_policy(mut self, change_policy: ChangeSpendPolicy) -> Self {
+        self.change_policy = change_policy;
+        self
+    }
+
+    pub fn force_non_witness_utxo(mut self) -> Self {
+        self.force_non_witness_utxo = true;
+        self
+    }
+
     pub fn coin_selection<P: CoinSelectionAlgorithm>(self, coin_selection: P) -> TxBuilder<P> {
         TxBuilder {
             addressees: self.addressees,
@@ -136,6 +146,7 @@ impl<Cs: CoinSelectionAlgorithm> TxBuilder<Cs> {
             rbf: self.rbf,
             version: self.version,
             change_policy: self.change_policy,
+            force_non_witness_utxo: self.force_non_witness_utxo,
             coin_selection,
         }
     }