]> Untitled Git - bdk/commitdiff
Implementing review suggestions from afilini
authorRichard Ulrich <richi@paraeasy.ch>
Thu, 22 Oct 2020 07:11:58 +0000 (09:11 +0200)
committerRichard Ulrich <richi@paraeasy.ch>
Thu, 22 Oct 2020 07:11:58 +0000 (09:11 +0200)
src/error.rs
src/wallet/mod.rs
src/wallet/tx_builder.rs

index 6628cc0d54af923494fed31c92ad7ba14108801a..659641332e8a46f7d0a577044d58685293cafae3 100644 (file)
@@ -47,6 +47,9 @@ pub enum Error {
     FeeRateTooLow {
         required: crate::types::FeeRate,
     },
+    FeeTooLow {
+        required: u64,
+    },
 
     Key(crate::keys::KeyError),
 
index 8f9a3a066ee0bae0b77ec22c2563970d57970fc2..a5754e7093c944024b2a9847b1bf12c40f0b5c2d 100644 (file)
@@ -299,13 +299,20 @@ where
             output: vec![],
         };
 
-        let fee_rate = get_fee_rate(&builder.fee_policy);
+        let (fee_rate, mut fee_amount) = match builder
+            .fee_policy
+            .as_ref()
+            .unwrap_or(&FeePolicy::FeeRate(FeeRate::default()))
+        {
+            FeePolicy::FeeAmount(amount) => (FeeRate::from_sat_per_vb(0.0), *amount as f32),
+            FeePolicy::FeeRate(rate) => (*rate, 0.0),
+        };
+
         if builder.send_all && builder.recipients.len() != 1 {
             return Err(Error::SendAllMultipleOutputs);
         }
 
         // we keep it as a float while we accumulate it, and only round it at the end
-        let mut fee_amount: f32 = 0.0;
         let mut outgoing: u64 = 0;
         let mut received: u64 = 0;
 
@@ -394,13 +401,6 @@ where
 
         let mut fee_amount = fee_amount.ceil() as u64;
 
-        if builder.has_absolute_fee() {
-            fee_amount = match builder.fee_policy.as_ref().unwrap() {
-                FeePolicy::FeeAmount(amount) => *amount,
-                _ => fee_amount,
-            }
-        };
-
         let change_val = (selected_amount - outgoing).saturating_sub(fee_amount);
         if !builder.send_all && !change_val.is_dust() {
             let mut change_output = change_output.unwrap();
@@ -493,13 +493,28 @@ where
         // the new tx must "pay for its bandwidth"
         let vbytes = tx.get_weight() as f32 / 4.0;
         let required_feerate = FeeRate::from_sat_per_vb(details.fees as f32 / vbytes + 1.0);
-        let new_feerate = get_fee_rate(&builder.fee_policy);
-
-        if new_feerate < required_feerate && !builder.has_absolute_fee() {
-            return Err(Error::FeeRateTooLow {
-                required: required_feerate,
-            });
-        }
+        let new_feerate = match builder
+            .fee_policy
+            .as_ref()
+            .unwrap_or(&FeePolicy::FeeRate(FeeRate::default()))
+        {
+            FeePolicy::FeeAmount(amount) => {
+                if *amount < details.fees {
+                    return Err(Error::FeeTooLow {
+                        required: details.fees,
+                    });
+                }
+                FeeRate::from_sat_per_vb(0.0)
+            }
+            FeePolicy::FeeRate(rate) => {
+                if *rate < required_feerate {
+                    return Err(Error::FeeRateTooLow {
+                        required: required_feerate,
+                    });
+                }
+                *rate
+            }
+        };
 
         if builder.send_all && tx.output.len() > 1 {
             return Err(Error::SendAllMultipleOutputs);
@@ -630,14 +645,13 @@ where
         must_use_utxos.append(&mut original_utxos);
 
         let amount_needed = tx.output.iter().fold(0, |acc, out| acc + out.value);
-        let initial_fee = tx.get_weight() as f32 / 4.0 * new_feerate.as_sat_vb();
-        let initial_fee = if builder.has_absolute_fee() {
-            match builder.fee_policy.as_ref().unwrap() {
-                FeePolicy::FeeAmount(amount) => *amount as f32,
-                _ => initial_fee,
-            }
-        } else {
-            initial_fee
+        let initial_fee = match builder
+            .fee_policy
+            .as_ref()
+            .unwrap_or(&FeePolicy::FeeRate(FeeRate::default()))
+        {
+            FeePolicy::FeeAmount(amount) => *amount as f32,
+            FeePolicy::FeeRate(_) => tx.get_weight() as f32 / 4.0 * new_feerate.as_sat_vb(),
         };
 
         let coin_selection::CoinSelectionResult {
@@ -669,12 +683,6 @@ where
         details.sent = selected_amount;
 
         let mut fee_amount = fee_amount.ceil() as u64;
-        if builder.has_absolute_fee() {
-            fee_amount = match builder.fee_policy.as_ref().unwrap() {
-                FeePolicy::FeeAmount(amount) => *amount,
-                _ => fee_amount,
-            }
-        };
         let removed_output_fee_cost = (serialize(&removed_updatable_output).len() as f32
             * new_feerate.as_sat_vb())
         .ceil() as u64;
@@ -682,23 +690,14 @@ where
         let change_val = selected_amount - amount_needed - fee_amount;
         let change_val_after_add = change_val.saturating_sub(removed_output_fee_cost);
         if !builder.send_all && !change_val_after_add.is_dust() {
-            if builder.has_absolute_fee() {
-                removed_updatable_output.value = change_val_after_add + removed_output_fee_cost;
-                details.received += change_val_after_add + removed_output_fee_cost;
-            } else {
-                removed_updatable_output.value = change_val_after_add;
-                fee_amount += removed_output_fee_cost;
-                details.received += change_val_after_add;
-            }
+            removed_updatable_output.value = change_val_after_add;
+            fee_amount += removed_output_fee_cost;
+            details.received += change_val_after_add;
 
             tx.output.push(removed_updatable_output);
         } else if builder.send_all && !change_val_after_add.is_dust() {
-            if builder.has_absolute_fee() {
-                removed_updatable_output.value = change_val_after_add + removed_output_fee_cost;
-            } else {
-                removed_updatable_output.value = change_val_after_add;
-                fee_amount += removed_output_fee_cost;
-            }
+            removed_updatable_output.value = change_val_after_add;
+            fee_amount += removed_output_fee_cost;
 
             // send_all to our address
             if self.is_mine(&removed_updatable_output.script_pubkey)? {
@@ -1242,17 +1241,6 @@ where
     }
 }
 
-/// get the fee rate if specified or a default
-fn get_fee_rate(fee_policy: &Option<FeePolicy>) -> FeeRate {
-    if fee_policy.is_none() {
-        return FeeRate::default();
-    }
-    match fee_policy.as_ref().unwrap() {
-        FeePolicy::FeeRate(fr) => *fr,
-        _ => FeeRate::default(),
-    }
-}
-
 #[cfg(test)]
 mod test {
     use std::str::FromStr;
@@ -1716,6 +1704,52 @@ mod test {
             .unwrap();
 
         assert_eq!(details.fees, 100);
+        assert_eq!(psbt.global.unsigned_tx.output.len(), 1);
+        assert_eq!(
+            psbt.global.unsigned_tx.output[0].value,
+            50_000 - details.fees
+        );
+    }
+
+    #[test]
+    fn test_create_tx_absolute_zero_fee() {
+        let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
+        let addr = wallet.get_new_address().unwrap();
+        let (psbt, details) = wallet
+            .create_tx(
+                TxBuilder::with_recipients(vec![(addr.script_pubkey(), 0)])
+                    .fee_absolute(0)
+                    .send_all(),
+            )
+            .unwrap();
+
+        assert_eq!(details.fees, 0);
+        assert_eq!(psbt.global.unsigned_tx.output.len(), 1);
+        assert_eq!(
+            psbt.global.unsigned_tx.output[0].value,
+            50_000 - details.fees
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "InsufficientFunds")]
+    fn test_create_tx_absolute_high_fee() {
+        let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
+        let addr = wallet.get_new_address().unwrap();
+        let (psbt, details) = wallet
+            .create_tx(
+                TxBuilder::with_recipients(vec![(addr.script_pubkey(), 0)])
+                    .fee_absolute(60_000)
+                    .send_all(),
+            )
+            .unwrap();
+
+        assert_eq!(details.fees, 0);
+        assert_eq!(psbt.global.unsigned_tx.output.len(), 1);
+        assert_eq!(
+            psbt.global.unsigned_tx.output[0].value,
+            50_000 - details.fees
+        );
     }
 
     #[test]
@@ -2044,6 +2078,48 @@ mod test {
             .unwrap();
     }
 
+    #[test]
+    #[should_panic(expected = "FeeTooLow")]
+    fn test_bump_fee_low_fee() {
+        let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
+        let addr = wallet.get_new_address().unwrap();
+        let (psbt, mut details) = wallet
+            .create_tx(
+                TxBuilder::with_recipients(vec![(addr.script_pubkey(), 25_000)]).enable_rbf(),
+            )
+            .unwrap();
+        let tx = psbt.extract_tx();
+        let txid = tx.txid();
+        // skip saving the utxos, we know they can't be used anyways
+        details.transaction = Some(tx);
+        wallet.database.borrow_mut().set_tx(&details).unwrap();
+
+        wallet
+            .bump_fee(&txid, TxBuilder::new().fee_absolute(10))
+            .unwrap();
+    }
+
+    #[test]
+    #[should_panic(expected = "FeeTooLow")]
+    fn test_bump_fee_zero_abs() {
+        let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
+        let addr = wallet.get_new_address().unwrap();
+        let (psbt, mut details) = wallet
+            .create_tx(
+                TxBuilder::with_recipients(vec![(addr.script_pubkey(), 25_000)]).enable_rbf(),
+            )
+            .unwrap();
+        let tx = psbt.extract_tx();
+        let txid = tx.txid();
+        // skip saving the utxos, we know they can't be used anyways
+        details.transaction = Some(tx);
+        wallet.database.borrow_mut().set_tx(&details).unwrap();
+
+        wallet
+            .bump_fee(&txid, TxBuilder::new().fee_absolute(0))
+            .unwrap();
+    }
+
     #[test]
     fn test_bump_fee_reduce_change() {
         let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
index 65560a0590b8a10900467dbaa8c85efd8177ba48..aca6fc6f69c6b5c4b38512d2a368a5d53225bf1f 100644 (file)
@@ -82,6 +82,12 @@ pub enum FeePolicy {
     FeeAmount(u64),
 }
 
+impl std::default::Default for FeePolicy {
+    fn default() -> Self {
+        FeePolicy::FeeRate(FeeRate::default_min_relay_fee())
+    }
+}
+
 // Unfortunately derive doesn't work with `PhantomData`: https://github.com/rust-lang/rust/issues/26925
 impl<D: Database, Cs: CoinSelectionAlgorithm<D>> Default for TxBuilder<D, Cs>
 where
@@ -315,14 +321,6 @@ impl<D: Database, Cs: CoinSelectionAlgorithm<D>> TxBuilder<D, Cs> {
             phantom: PhantomData,
         }
     }
-
-    /// Returns true if an absolute fee was specified
-    pub fn has_absolute_fee(&self) -> bool {
-        if self.fee_policy.is_none() {
-            return false;
-        };
-        matches!(self.fee_policy.as_ref().unwrap(), FeePolicy::FeeAmount(_))
-    }
 }
 
 /// Ordering of the transaction's inputs and outputs