diff --git a/src/selection.rs b/src/selection.rs index 962c686..4306ba6 100644 --- a/src/selection.rs +++ b/src/selection.rs @@ -174,16 +174,26 @@ impl core::fmt::Display for CreatePsbtError { impl std::error::Error for CreatePsbtError {} impl Selection { - /// Returns none if there is a mismatch of units in `locktimes`. - fn _accumulate_max_locktime( + /// Accumulates the maximum locktime from an iterator of input-required locktimes. + /// + /// Returns the `fallback_locktime` if the locktimes iterator is empty, `Ok(lock_time)` with + /// the maximum locktime if all items share the same unit. Errors if there is a mismatch of + /// lock type units among the required locktimes. + fn accumulate_max_locktime( locktimes: impl IntoIterator, - ) -> Option { + fallback_locktime: absolute::LockTime, + ) -> Result { + // Accumulate locktimes required by inputs. An input-vs-input unit mismatch is an error. + // The fallback is only used when it is compatible with the input requirements. + // If the fallback is a different unit from the required locktime it is + // intentionally ignored so that a height-based fallback does not conflict with a + // time-based CLTV requirement. let mut acc = Option::::None; for locktime in locktimes { match &mut acc { Some(acc) => { if !acc.is_same_unit(locktime) { - return None; + return Err(CreatePsbtError::LockTypeMismatch); } if acc.is_implied_by(locktime) { *acc = locktime; @@ -192,7 +202,20 @@ impl Selection { acc => *acc = Some(locktime), }; } - acc + match acc { + // No required locktimes from inputs: use fallback directly. + None => Ok(fallback_locktime), + // Same unit as fallback: take the maximum of required and fallback. + Some(lock_time) if lock_time.is_same_unit(fallback_locktime) => { + if lock_time.is_implied_by(fallback_locktime) { + Ok(fallback_locktime) + } else { + Ok(lock_time) + } + } + // Fallback is a different unit: use required locktime and ignore fallback. + Some(lock_time) => Ok(lock_time), + } } /// Create PSBT. @@ -209,13 +232,12 @@ impl Selection { ) -> Result { let mut tx = bitcoin::Transaction { version: params.version, - lock_time: Self::_accumulate_max_locktime( + lock_time: Self::accumulate_max_locktime( self.inputs .iter() - .filter_map(|input| input.absolute_timelock()) - .chain([params.fallback_locktime]), - ) - .ok_or(CreatePsbtError::LockTypeMismatch)?, + .filter_map(|input| input.absolute_timelock()), + params.fallback_locktime, + )?, input: self .inputs .iter() @@ -315,7 +337,7 @@ mod tests { const TEST_DESCRIPTOR_PK: &str = "[83737d5e/86h/1h/0h]tpubDDR5GgtoxS8fJyjjvdahN4VzV5DV6jtbcyvVXhEKq2XtpxjxBXmxH3r8QrNbQqHg4bJM1EGkxi7Pjfkgnui9jQWqS7kxHvX6rhUeriLDKxz/0/*"; #[test] - fn test_fallback_locktime() -> anyhow::Result<()> { + fn test_fallback_locktime_height() -> anyhow::Result<()> { let abs_locktime = absolute::LockTime::from_consensus(100_000); let secp = Secp256k1::new(); let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3"; @@ -389,6 +411,64 @@ mod tests { Ok(()) } + /// Tests that a height-based fallback locktime is ignored when the input + /// requires a time-based (UNIX timestamp) CLTV, and that an explicit time-based + /// fallback greater than the requirement is respected. + #[test] + fn test_fallback_locktime_respects_lock_type() -> anyhow::Result<()> { + let time_locktime = absolute::LockTime::from_consensus(1_734_230_218); + let secp = Secp256k1::new(); + let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3"; + let desc_str = format!("wsh(and_v(v:pk({pk}),after({time_locktime})))"); + let desc_pk: DescriptorPublicKey = pk.parse()?; + let (desc, _) = Descriptor::parse_descriptor(&secp, &desc_str)?; + let plan = desc + .at_derivation_index(0)? + .plan(&Assets::new().add(desc_pk).after(time_locktime)) + .unwrap(); + + let prev_tx = Transaction { + version: transaction::Version::TWO, + lock_time: absolute::LockTime::ZERO, + input: vec![TxIn::default()], + output: vec![TxOut { + script_pubkey: desc.at_derivation_index(0)?.script_pubkey(), + value: Amount::ONE_BTC, + }], + }; + let input = Input::from_prev_tx(plan, prev_tx, 0, None)?; + + let selection = Selection { + inputs: vec![input], + outputs: vec![Output::with_descriptor( + desc.at_derivation_index(1)?, + Amount::from_sat(1000), + )], + }; + + // Default fallback is height 0 (block-height unit). It is incompatible with the + // time-based CLTV requirement, so it must be ignored. + let psbt = selection.create_psbt(PsbtParams::default())?; + assert_eq!( + psbt.unsigned_tx.lock_time, time_locktime, + "time-based CLTV requirement should be used; height-based fallback must be ignored", + ); + + // An explicit time-based fallback *greater* than the requirement should be respected. + let larger_time = absolute::LockTime::from_consensus(1_772_167_108); + assert!(larger_time > time_locktime); + let psbt = selection.create_psbt(PsbtParams { + fallback_locktime: larger_time, + ..Default::default() + })?; + assert_eq!( + psbt.unsigned_tx.lock_time, larger_time, + "a larger time-based fallback should override the CLTV requirement", + ); + + Ok(()) + } + pub fn setup_test_input(confirmation_height: u32) -> anyhow::Result { let secp = Secp256k1::new(); let desc = Descriptor::parse_descriptor(&secp, TEST_DESCRIPTOR)