Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 91 additions & 11 deletions src/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,26 @@ impl core::fmt::Display for CreatePsbtError {
impl std::error::Error for CreatePsbtError {}

impl Selection {
/// Returns none if there is a mismatch of units in `locktimes`.
fn _accumulate_max_locktime(
/// Accumulates the maximum locktime from an iterator of input-required locktimes.
///
/// Returns the `fallback_locktime` if the locktimes iterator is empty, `Ok(lock_time)` with
/// the maximum locktime if all items share the same unit. Errors if there is a mismatch of
/// lock type units among the required locktimes.
fn accumulate_max_locktime(
locktimes: impl IntoIterator<Item = absolute::LockTime>,
) -> Option<absolute::LockTime> {
fallback_locktime: absolute::LockTime,
) -> Result<absolute::LockTime, CreatePsbtError> {
// Accumulate locktimes required by inputs. An input-vs-input unit mismatch is an error.
// The fallback is only used when it is compatible with the input requirements.
// If the fallback is a different unit from the required locktime it is
// intentionally ignored so that a height-based fallback does not conflict with a
// time-based CLTV requirement.
let mut acc = Option::<absolute::LockTime>::None;
for locktime in locktimes {
match &mut acc {
Some(acc) => {
if !acc.is_same_unit(locktime) {
return None;
return Err(CreatePsbtError::LockTypeMismatch);
}
if acc.is_implied_by(locktime) {
*acc = locktime;
Expand All @@ -192,7 +202,20 @@ impl Selection {
acc => *acc = Some(locktime),
};
}
acc
match acc {
// No required locktimes from inputs: use fallback directly.
None => Ok(fallback_locktime),
// Same unit as fallback: take the maximum of required and fallback.
Some(lock_time) if lock_time.is_same_unit(fallback_locktime) => {
if lock_time.is_implied_by(fallback_locktime) {
Ok(fallback_locktime)
} else {
Ok(lock_time)
}
}
// Fallback is a different unit: use required locktime and ignore fallback.
Some(lock_time) => Ok(lock_time),
}
}

/// Create PSBT.
Expand All @@ -209,13 +232,12 @@ impl Selection {
) -> Result<bitcoin::Psbt, CreatePsbtError> {
let mut tx = bitcoin::Transaction {
version: params.version,
lock_time: Self::_accumulate_max_locktime(
lock_time: Self::accumulate_max_locktime(
self.inputs
.iter()
.filter_map(|input| input.absolute_timelock())
.chain([params.fallback_locktime]),
)
.ok_or(CreatePsbtError::LockTypeMismatch)?,
.filter_map(|input| input.absolute_timelock()),
params.fallback_locktime,
)?,
input: self
.inputs
.iter()
Expand Down Expand Up @@ -315,7 +337,7 @@ mod tests {
const TEST_DESCRIPTOR_PK: &str = "[83737d5e/86h/1h/0h]tpubDDR5GgtoxS8fJyjjvdahN4VzV5DV6jtbcyvVXhEKq2XtpxjxBXmxH3r8QrNbQqHg4bJM1EGkxi7Pjfkgnui9jQWqS7kxHvX6rhUeriLDKxz/0/*";

#[test]
fn test_fallback_locktime() -> anyhow::Result<()> {
fn test_fallback_locktime_height() -> anyhow::Result<()> {
let abs_locktime = absolute::LockTime::from_consensus(100_000);
let secp = Secp256k1::new();
let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3";
Expand Down Expand Up @@ -389,6 +411,64 @@ mod tests {
Ok(())
}

/// Tests that a height-based fallback locktime is ignored when the input
/// requires a time-based (UNIX timestamp) CLTV, and that an explicit time-based
/// fallback greater than the requirement is respected.
#[test]
fn test_fallback_locktime_respects_lock_type() -> anyhow::Result<()> {
let time_locktime = absolute::LockTime::from_consensus(1_734_230_218);
let secp = Secp256k1::new();
let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3";
let desc_str = format!("wsh(and_v(v:pk({pk}),after({time_locktime})))");
let desc_pk: DescriptorPublicKey = pk.parse()?;
let (desc, _) = Descriptor::parse_descriptor(&secp, &desc_str)?;
let plan = desc
.at_derivation_index(0)?
.plan(&Assets::new().add(desc_pk).after(time_locktime))
.unwrap();

let prev_tx = Transaction {
version: transaction::Version::TWO,
lock_time: absolute::LockTime::ZERO,
input: vec![TxIn::default()],
output: vec![TxOut {
script_pubkey: desc.at_derivation_index(0)?.script_pubkey(),
value: Amount::ONE_BTC,
}],
};
let input = Input::from_prev_tx(plan, prev_tx, 0, None)?;

let selection = Selection {
inputs: vec![input],
outputs: vec![Output::with_descriptor(
desc.at_derivation_index(1)?,
Amount::from_sat(1000),
)],
};

// Default fallback is height 0 (block-height unit). It is incompatible with the
// time-based CLTV requirement, so it must be ignored.
let psbt = selection.create_psbt(PsbtParams::default())?;
assert_eq!(
psbt.unsigned_tx.lock_time, time_locktime,
"time-based CLTV requirement should be used; height-based fallback must be ignored",
);

// An explicit time-based fallback *greater* than the requirement should be respected.
let larger_time = absolute::LockTime::from_consensus(1_772_167_108);
assert!(larger_time > time_locktime);
let psbt = selection.create_psbt(PsbtParams {
fallback_locktime: larger_time,
..Default::default()
})?;
assert_eq!(
psbt.unsigned_tx.lock_time, larger_time,
"a larger time-based fallback should override the CLTV requirement",
);

Ok(())
}

pub fn setup_test_input(confirmation_height: u32) -> anyhow::Result<Input> {
let secp = Secp256k1::new();
let desc = Descriptor::parse_descriptor(&secp, TEST_DESCRIPTOR)
Expand Down
Loading