]> Untitled Git - bdk/commitdiff
`get_checksum_bytes` now checks input data for checksum
author志宇 <hello@evanlinjin.me>
Thu, 29 Sep 2022 05:06:03 +0000 (13:06 +0800)
committer志宇 <hello@evanlinjin.me>
Thu, 29 Sep 2022 05:06:03 +0000 (13:06 +0800)
If `exclude_hash` is set, we split the input data, and if a checksum
already existed within the original data, we check the calculated
checksum against the original checksum.

Additionally, the implementation of `IntoWalletDescriptor` for `&str`
has been refactored for clarity.

src/descriptor/checksum.rs
src/descriptor/mod.rs
src/wallet/mod.rs

index 5ed1151bd8b9ae0d143f8e812b79460310e996a6..8dfdac49b44608b91c85e6cf5bc51fcf5174e458 100644 (file)
@@ -41,12 +41,21 @@ fn poly_mod(mut c: u64, val: u64) -> u64 {
     c
 }
 
-/// Computes the checksum bytes of a descriptor
-pub fn get_checksum_bytes(desc: &str) -> Result<[u8; 8], DescriptorError> {
+/// Computes the checksum bytes of a descriptor.
+/// `exclude_hash = true` ignores all data after the first '#' (inclusive).
+pub fn get_checksum_bytes(mut desc: &str, exclude_hash: bool) -> Result<[u8; 8], DescriptorError> {
     let mut c = 1;
     let mut cls = 0;
     let mut clscount = 0;
 
+    let mut original_checksum = None;
+    if exclude_hash {
+        if let Some(split) = desc.split_once('#') {
+            desc = split.0;
+            original_checksum = Some(split.1);
+        }
+    }
+
     for ch in desc.as_bytes() {
         let pos = INPUT_CHARSET
             .iter()
@@ -72,13 +81,20 @@ pub fn get_checksum_bytes(desc: &str) -> Result<[u8; 8], DescriptorError> {
         checksum[j] = CHECKSUM_CHARSET[((c >> (5 * (7 - j))) & 31) as usize];
     }
 
+    // if input data already had a checksum, check calculated checksum against original checksum
+    if let Some(original_checksum) = original_checksum {
+        if original_checksum.as_bytes() != &checksum {
+            return Err(DescriptorError::InvalidDescriptorChecksum);
+        }
+    }
+
     Ok(checksum)
 }
 
 /// Compute the checksum of a descriptor
 pub fn get_checksum(desc: &str) -> Result<String, DescriptorError> {
     // unsafe is okay here as the checksum only uses bytes in `CHECKSUM_CHARSET`
-    get_checksum_bytes(desc).map(|b| unsafe { String::from_utf8_unchecked(b.to_vec()) })
+    get_checksum_bytes(desc, true).map(|b| unsafe { String::from_utf8_unchecked(b.to_vec()) })
 }
 
 #[cfg(test)]
index 802ccd19ca2ababa877b4e1552058581eac2763b..7c51d27fc912b4eca5919b07d8b38315da135513 100644 (file)
@@ -40,6 +40,7 @@ pub mod policy;
 pub mod template;
 
 pub use self::checksum::get_checksum;
+use self::checksum::get_checksum_bytes;
 pub use self::derived::{AsDerived, DerivedDescriptorKey};
 pub use self::error::Error as DescriptorError;
 pub use self::policy::Policy;
@@ -84,19 +85,15 @@ impl IntoWalletDescriptor for &str {
         secp: &SecpCtx,
         network: Network,
     ) -> Result<(ExtendedDescriptor, KeyMap), DescriptorError> {
-        let descriptor = if self.contains('#') {
-            let parts: Vec<&str> = self.splitn(2, '#').collect();
-            if !get_checksum(parts[0])
-                .ok()
-                .map(|computed| computed == parts[1])
-                .unwrap_or(false)
-            {
-                return Err(DescriptorError::InvalidDescriptorChecksum);
+        let descriptor = match self.split_once('#') {
+            Some((desc, original_checksum)) => {
+                let checksum = get_checksum_bytes(desc, false)?;
+                if original_checksum.as_bytes() != &checksum {
+                    return Err(DescriptorError::InvalidDescriptorChecksum);
+                }
+                desc
             }
-
-            parts[0]
-        } else {
-            self
+            None => self,
         };
 
         ExtendedDescriptor::parse_descriptor(secp, descriptor)?
index 2e3d9fdffbdab0081915082ba72ab5fb06825b8f..776e1740a1afab9ae0c160a974ca5f67afbff9f6 100644 (file)
@@ -1943,15 +1943,10 @@ pub(crate) mod test {
         let (wallet, _, _) = get_funded_wallet(get_test_wpkh());
         let checksum = wallet.descriptor_checksum(KeychainKind::External);
         assert_eq!(checksum.len(), 8);
-
-        let raw_descriptor = wallet
-            .descriptor
-            .to_string()
-            .split_once('#')
-            .unwrap()
-            .0
-            .to_string();
-        assert_eq!(get_checksum(&raw_descriptor).unwrap(), checksum);
+        assert_eq!(
+            get_checksum(&wallet.descriptor.to_string()).unwrap(),
+            checksum
+        );
     }
 
     #[test]