]> Untitled Git - bdk/commitdiff
[wallet] Add RBF and custom versions in TxBuilder
authorAlekos Filini <alekos.filini@gmail.com>
Fri, 7 Aug 2020 14:30:19 +0000 (16:30 +0200)
committerAlekos Filini <alekos.filini@gmail.com>
Mon, 10 Aug 2020 15:18:11 +0000 (17:18 +0200)
src/wallet/mod.rs
src/wallet/tx_builder.rs

index fe1255a5277275639758d4e30e34afa68cac6128..fe0f2b9b869fb988036787b897617372cbbb1a5c 100644 (file)
@@ -135,9 +135,36 @@ where
             policy.get_requirements(builder.policy_path.as_ref().unwrap_or(&BTreeMap::new()))?;
         debug!("requirements: {:?}", requirements);
 
+        let version = match builder.version {
+            tx_builder::Version(0) => return Err(Error::Generic("Invalid version `0`".into())),
+            tx_builder::Version(1) if requirements.csv.is_some() => {
+                return Err(Error::Generic(
+                    "TxBuilder requested version `1`, but at least `2` is needed to use OP_CSV"
+                        .into(),
+                ))
+            }
+            tx_builder::Version(x) => x,
+        };
+
+        let lock_time = match builder.locktime {
+            None => requirements.timelock.unwrap_or(0),
+            Some(x) if requirements.timelock.is_none() => x,
+            Some(x) if requirements.timelock.unwrap() <= x => x,
+            Some(x) => return Err(Error::Generic(format!("TxBuilder requested timelock of `{}`, but at least `{}` is required to spend from this script", x, requirements.timelock.unwrap())))
+        };
+
+        let n_sequence = match (builder.rbf, requirements.csv) {
+            (None, Some(csv)) => csv,
+            (Some(rbf), Some(csv)) if rbf < csv => return Err(Error::Generic(format!("Cannot enable RBF with nSequence `{}`, since at least `{}` is required to spend with OP_CSV", rbf, csv))),
+            (None, _) if requirements.timelock.is_some() => 0xFFFFFFFE,
+            (Some(rbf), _) if rbf >= 0xFFFFFFFE => return Err(Error::Generic("Cannot enable RBF with anumber >= 0xFFFFFFFE".into())),
+            (Some(rbf), _) => rbf,
+            (None, _) => 0xFFFFFFFF,
+        };
+
         let mut tx = Transaction {
-            version: 2,
-            lock_time: requirements.timelock.unwrap_or(0),
+            version,
+            lock_time,
             input: vec![],
             output: vec![],
         };
@@ -206,11 +233,6 @@ where
         )?;
         let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip();
 
-        let n_sequence = match requirements.csv {
-            Some(csv) => csv,
-            _ if requirements.timelock.is_some() => 0xFFFFFFFE,
-            _ => 0xFFFFFFFF,
-        };
         txin.iter_mut().for_each(|i| i.sequence = n_sequence);
         tx.input = txin;
 
index 74c69468ccc9da9b563e464063901b936617eea8..1b43fbdb1a363c2f4059d41ea3080e0e01d61bb6 100644 (file)
@@ -18,6 +18,8 @@ pub struct TxBuilder<Cs: CoinSelectionAlgorithm> {
     pub(crate) sighash: Option<SigHashType>,
     pub(crate) ordering: TxOrdering,
     pub(crate) locktime: Option<u32>,
+    pub(crate) rbf: Option<u32>,
+    pub(crate) version: Version,
     pub(crate) coin_selection: Cs,
 }
 
@@ -92,6 +94,20 @@ impl<Cs: CoinSelectionAlgorithm> TxBuilder<Cs> {
         self
     }
 
+    pub fn enable_rbf(self) -> Self {
+        self.enable_rbf_with_sequence(0xFFFFFFFD)
+    }
+
+    pub fn enable_rbf_with_sequence(mut self, nsequence: u32) -> Self {
+        self.rbf = Some(nsequence);
+        self
+    }
+
+    pub fn version(mut self, version: u32) -> Self {
+        self.version = Version(version);
+        self
+    }
+
     pub fn coin_selection<P: CoinSelectionAlgorithm>(self, coin_selection: P) -> TxBuilder<P> {
         TxBuilder {
             addressees: self.addressees,
@@ -103,6 +119,8 @@ impl<Cs: CoinSelectionAlgorithm> TxBuilder<Cs> {
             sighash: self.sighash,
             ordering: self.ordering,
             locktime: self.locktime,
+            rbf: self.rbf,
+            version: self.version,
             coin_selection,
         }
     }
@@ -148,6 +166,16 @@ impl TxOrdering {
     }
 }
 
+// Helper type that wraps u32 and has a default value of 1
+#[derive(Debug)]
+pub(crate) struct Version(pub(crate) u32);
+
+impl Default for Version {
+    fn default() -> Self {
+        Version(1)
+    }
+}
+
 #[cfg(test)]
 mod test {
     const ORDERING_TEST_TX: &'static str = "0200000003c26f3eb7932f7acddc5ddd26602b77e7516079b03090a16e2c2f54\