diff --git a/components/clarity-static-cost/src/static_cost/cost_analysis.rs b/components/clarity-static-cost/src/static_cost/cost_analysis.rs index 358e71c24..9779ba73e 100644 --- a/components/clarity-static-cost/src/static_cost/cost_analysis.rs +++ b/components/clarity-static-cost/src/static_cost/cost_analysis.rs @@ -9,8 +9,8 @@ use clarity::vm::costs::ExecutionCost; use clarity::vm::functions::NativeFunctions; use clarity::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; use clarity::vm::types::{ - parse_name_type_pairs, PrincipalData, QualifiedContractIdentifier, SequenceSubtype, - TupleTypeSignature, TypeSignature, TypeSignatureExt, + parse_name_type_pairs, ListTypeData, PrincipalData, QualifiedContractIdentifier, + SequenceSubtype, StringSubtype, TupleTypeSignature, TypeSignature, TypeSignatureExt, }; use clarity::vm::variables::lookup_reserved_variable; use clarity::vm::{ClarityVersion, Value}; @@ -98,6 +98,9 @@ pub struct UserArgumentsContext { pub map_types: HashMap, /// Map from data-var name to its value type, pre-populated from define-data-var declarations pub data_var_types: HashMap, + /// Map from argument name to a known-constant value (call-site narrowing). + /// Treated as a hint, never required. 
+ pub known_values: HashMap, } impl UserArgumentsContext { @@ -137,6 +140,14 @@ impl UserArgumentsContext { pub fn get_data_var_type(&self, name: &ClarityName) -> Option<&TypeSignature> { self.data_var_types.get(name) } + + pub fn add_known_value(&mut self, name: ClarityName, value: Value) { + self.known_values.insert(name, value); + } + + pub fn get_known_value(&self, name: &ClarityName) -> Option<&Value> { + self.known_values.get(name) + } } /// A type to track summed execution costs for different paths @@ -470,6 +481,15 @@ pub fn static_cost_tree_from_ast( } } } + // Build function_defs map for call-site narrowing + let function_defs: HashMap = exprs + .iter() + .filter_map(|expr| { + let function_name = extract_function_name(expr)?; + let list = expr.match_list()?; + Some((function_name, list)) + }) + .collect(); // second pass computes the cost for expr in exprs { if let Some(function_name) = extract_function_name(expr) { @@ -477,6 +497,7 @@ pub fn static_cost_tree_from_ast( expr, &user_args, &costs_map, + &function_defs, clarity_version, epoch, env, @@ -485,8 +506,9 @@ pub fn static_cost_tree_from_ast( )?; // Compute static cost for this function so subsequent calls can look it up. // Include the overhead costs so that callers get the full cost - // of invoking this function. Note: LookupFunction is NOT added here because - // the dynamic VM does not charge it for user-defined function calls. + // of invoking this function. Note: LookupFunction is NOT added here; + // it is added at the call site instead (in build_listlike_cost_analysis_tree + // for direct calls, and via fn_lookup_cost for map/filter/fold). 
let mut sc: StaticCost = calculate_total_cost_with_branching(&cost_analysis_tree).into(); let overhead = compute_function_overhead_costs(None, &function_name, exprs, epoch); @@ -599,6 +621,7 @@ pub fn build_cost_analysis_tree( expr: &SymbolicExpression, user_args: &UserArgumentsContext, cost_map: &HashMap>, + function_defs: &HashMap, clarity_version: &ClarityVersion, epoch: StacksEpochId, env: &mut ExecutionState, @@ -614,8 +637,11 @@ pub fn build_cost_analysis_tree( list, user_args, cost_map, + function_defs, clarity_version, epoch, + None, + None, env, invoke_ctx, )?; @@ -625,6 +651,7 @@ pub fn build_cost_analysis_tree( list, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -638,6 +665,7 @@ pub fn build_cost_analysis_tree( list, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -695,7 +723,9 @@ pub fn build_cost_analysis_tree( // For now, we use let_depth which tracks let-binding depth type_sig .as_ref() - .map(|sig| calculate_variable_lookup_cost_from_type(sig, let_depth, epoch)) + .map(|sig| { + calculate_variable_lookup_cost_from_type(sig, let_depth, epoch, false) + }) .unwrap_or(StaticCost::ZERO) } }; @@ -705,7 +735,7 @@ pub fn build_cost_analysis_tree( let final_cost = if let CostExprNode::UserArgument(ref arg_name, _) = expr_node { // Get the type from user_args and calculate cost from it if let Some(arg_type) = user_args.get_argument_type(arg_name) { - calculate_variable_lookup_cost_from_type(arg_type, let_depth, epoch) + calculate_variable_lookup_cost_from_type(arg_type, let_depth, epoch, false) } else { cost } @@ -732,14 +762,25 @@ pub fn build_cost_analysis_tree( } } -/// Calculate variable lookup cost from a TypeSignature +/// Calculate variable lookup cost from a TypeSignature. +/// +/// `exact` controls whether the type fully determines the runtime value size: +/// - `false` (declared types): the actual value size is unknown, so min uses +/// `min_size()` (e.g. empty list) and max uses `size()` (e.g. 
full-length list). +/// - `true` (narrowed call sites): the type is fully determined by the caller, +/// so `size()` is used for both min and max. fn calculate_variable_lookup_cost_from_type( type_sig: &TypeSignature, let_depth: u64, epoch: StacksEpochId, + exact: bool, ) -> StaticCost { let type_size = u64::from(type_sig.size().unwrap_or(0)); - let type_min_size = u64::from(type_sig.min_size().unwrap_or(0)); + let type_min_size = if exact { + type_size + } else { + u64::from(type_sig.min_size().unwrap_or(0)) + }; let mut variable_size_cost = ClarityCostFunction::LookupVariableSize .eval_for_epoch(type_size, epoch) @@ -822,6 +863,10 @@ fn infer_type_from_expression_with_args( if let Some(tuple_type) = infer_tuple_type_from_tuplecons(exprs, user_args, epoch) { return Ok(tuple_type); } + // Try ListCons inference: (list elem1 elem2 ...) + if let Some(list_type) = infer_type_from_listcons(exprs, user_args, epoch) { + return Ok(list_type); + } infer_type_from_expression(expr, epoch) } _ => infer_type_from_expression(expr, epoch), @@ -857,6 +902,28 @@ fn infer_tuple_type_from_tuplecons( .map(TypeSignature::TupleType) } +/// Try to infer a `ListType` from a ListCons expression like `(list u1 u2)`. +/// Returns the list type with max_len equal to the number of elements. +fn infer_type_from_listcons( + exprs: &[SymbolicExpression], + user_args: &UserArgumentsContext, + epoch: StacksEpochId, +) -> Option { + if exprs.first()?.match_atom()?.as_str() != "list" { + return None; + } + let elements = &exprs[1..]; + if elements.is_empty() { + return None; + } + // Infer the element type from the first element + let elem_type = infer_type_from_expression_with_args(&elements[0], user_args, epoch).ok()?; + let list_type = ListTypeData::new_list(elem_type, elements.len() as u32).ok()?; + Some(TypeSignature::SequenceType(SequenceSubtype::ListType( + list_type, + ))) +} + /// Infer type from a SymbolicExpression by examining its structure. 
/// This is a fallback when type_map is not available. pub(crate) fn infer_type_from_expression( @@ -910,11 +977,11 @@ fn multiply_cost(cost: &mut ExecutionCost, factor: u64) { fn get_map_filter_fold_list_max_len( list_exprs: &[SymbolicExpression], user_args: &UserArgumentsContext, + epoch: StacksEpochId, ) -> u64 { for expr in list_exprs { - if let Some(TypeSignature::SequenceType(SequenceSubtype::ListType(list_data))) = expr - .match_atom() - .and_then(|name| user_args.get_argument_type(name)) + if let Ok(TypeSignature::SequenceType(SequenceSubtype::ListType(list_data))) = + infer_type_from_expression_with_args(expr, user_args, epoch) { return list_data.get_max_len() as u64; } @@ -922,13 +989,242 @@ fn get_map_filter_fold_list_max_len( 1 } +/// Try to resolve an expression to a constant boolean known at static-analysis time. +/// Used for folding `if` branches when the condition is constant. +/// Returns `None` for any expression we cannot statically resolve to a literal bool. +fn try_resolve_constant_bool( + expr: &SymbolicExpression, + user_args: &UserArgumentsContext, +) -> Option { + match &expr.expr { + SymbolicExpressionType::LiteralValue(Value::Bool(b)) + | SymbolicExpressionType::AtomValue(Value::Bool(b)) => Some(*b), + SymbolicExpressionType::Atom(name) => match name.as_str() { + "true" => Some(true), + "false" => Some(false), + _ => match user_args.get_known_value(name)? { + Value::Bool(b) => Some(*b), + _ => None, + }, + }, + SymbolicExpressionType::List(list) => { + // Recognise `(not )` + let head = list.first()?.match_atom()?; + if head.as_str() == "not" && list.len() == 2 { + Some(!try_resolve_constant_bool(&list[1], user_args)?) + } else { + None + } + } + _ => None, + } +} + +/// Check if `actual` is a narrower type than `declared` (e.g., shorter list max length, +/// shorter buffer/string). 
+fn is_narrower_type(actual: &TypeSignature, declared: &TypeSignature) -> bool { + match (actual, declared) { + ( + TypeSignature::SequenceType(SequenceSubtype::ListType(actual_list)), + TypeSignature::SequenceType(SequenceSubtype::ListType(declared_list)), + ) => actual_list.get_max_len() < declared_list.get_max_len(), + ( + TypeSignature::SequenceType(SequenceSubtype::BufferType(actual_len)), + TypeSignature::SequenceType(SequenceSubtype::BufferType(declared_len)), + ) => u32::from(actual_len) < u32::from(declared_len), + ( + TypeSignature::SequenceType(SequenceSubtype::StringType(StringSubtype::ASCII( + actual_len, + ))), + TypeSignature::SequenceType(SequenceSubtype::StringType(StringSubtype::ASCII( + declared_len, + ))), + ) => u32::from(actual_len) < u32::from(declared_len), + ( + TypeSignature::SequenceType(SequenceSubtype::StringType(StringSubtype::UTF8( + actual_len, + ))), + TypeSignature::SequenceType(SequenceSubtype::StringType(StringSubtype::UTF8( + declared_len, + ))), + ) => u32::from(actual_len) < u32::from(declared_len), + _ => false, + } +} + +/// Walk the cost tree and recalculate UserArgument lookup costs using exact mode +/// (size() for both min and max). This is used after re-analyzing a function body +/// with narrowed argument types, where we know the actual types precisely. +/// Tracks `let_depth` by incrementing when entering `Let` or `If` nodes, mirroring +/// the depth tracking in `build_cost_analysis_tree`.
+fn fix_user_arg_costs_exact(node: &mut CostAnalysisNode, let_depth: u64, epoch: StacksEpochId) { + if let CostExprNode::UserArgument(_, ref arg_type) = node.expr { + node.cost = calculate_variable_lookup_cost_from_type(arg_type, let_depth, epoch, true); + } + let child_depth = match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::Let | NativeFunctions::If) => let_depth + 1, + _ => let_depth, + }; + for child in &mut node.children { + fix_user_arg_costs_exact(child, child_depth, epoch); + } +} + +/// Try to compute a narrowed cost for a user-function call when the actual +/// arguments have tighter types than the declared parameters. For example, +/// when calling `(add-many-64 (list u1))` where `add-many-64` declares +/// `(ns (list 64 uint))`, the list is known to be length 1 instead of up to 64. +/// +/// Returns `Some(narrowed_cost)` when narrowing is possible, `None` otherwise.
+fn try_narrow_user_function_cost( + fn_name: &str, + call_args: &[SymbolicExpression], + caller_user_args: &UserArgumentsContext, + function_defs: &HashMap, + cost_map: &HashMap>, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, + env: &mut ExecutionState, + invoke_ctx: &InvocationContext, +) -> Option { + let fn_def_list = function_defs.get(fn_name)?; + let signature = fn_def_list.get(1)?.match_list()?; + + // Extract declared parameter types from the function signature + let mut free_tracker = clarity::vm::costs::LimitedCostTracker::new_free(); + let params: Vec<(ClarityName, TypeSignature)> = signature + .iter() + .skip(1) + .filter_map(|arg_expr| { + let arg_list = arg_expr.match_list()?; + if arg_list.len() != 2 { + return None; + } + let name = arg_list[0].match_atom()?.clone(); + let arg_type = + TypeSignature::parse_type_repr(epoch, &arg_list[1], &mut free_tracker).ok()?; + Some((name, arg_type)) + }) + .collect(); + + // Infer actual argument types from the call site and check if any are narrower. + // Also extract known constant values from literal arguments or by propagating + // from the caller's known_values when an arg is itself an atom referring to + // a known-constant parameter. 
+ let mut any_narrower_or_known = false; + let mut narrowed_types: Vec = Vec::with_capacity(params.len()); + let mut known_values: Vec> = Vec::with_capacity(params.len()); + for (i, (_param_name, declared_type)) in params.iter().enumerate() { + let call_arg = call_args.get(i); + + // Type narrowing + let narrowed_type = call_arg + .and_then(|arg| infer_type_from_expression_with_args(arg, caller_user_args, epoch).ok()) + .filter(|actual_type| is_narrower_type(actual_type, declared_type)); + if let Some(t) = narrowed_type { + narrowed_types.push(t); + any_narrower_or_known = true; + } else { + narrowed_types.push(declared_type.clone()); + } + + // Known value extraction. Literals (`u1`, `0x01`, `"hi"`) come through as + // AtomValue/LiteralValue. Constant atoms `true`/`false` are reserved + // variables parsed as plain Atom. Caller-side parameters with a known + // value are also propagated. + let known = call_arg.and_then(|arg| { + arg.match_atom_value() + .or_else(|| arg.match_literal_value()) + .cloned() + .or_else(|| match arg.match_atom()?.as_str() { + "true" => Some(Value::Bool(true)), + "false" => Some(Value::Bool(false)), + name => caller_user_args.get_known_value(&name.into()).cloned(), + }) + }); + if known.is_some() { + any_narrower_or_known = true; + } + known_values.push(known); + } + + if !any_narrower_or_known { + return None; + } + + // Re-analyze the function body with narrowed argument types and known values + let (_, mut narrowed_tree) = build_function_definition_cost_analysis_tree( + fn_def_list, + caller_user_args, + cost_map, + function_defs, + clarity_version, + epoch, + Some(&narrowed_types), + Some(&known_values), + env, + invoke_ctx, + ) + .ok()?; + + // Fix UserArgument lookup costs to be exact in body subtrees only. + // Top-level UserArgument children are parameter declarations which don't + // incur variable lookup costs — only references within the body do. 
+ for child in &mut narrowed_tree.children { + if !matches!(child.expr, CostExprNode::UserArgument(..)) { + fix_user_arg_costs_exact(child, 0, epoch); + } + } + + let mut sc: StaticCost = super::calculate_total_cost_with_branching(&narrowed_tree).into(); + + // Add overhead (UserFunctionApplication + InnerTypeCheckCost per declared param). + let (arg_count, _) = extract_function_signature_from_list(fn_def_list).unwrap_or((0, &[])); + + let application_cost = ClarityCostFunction::UserFunctionApplication + .eval_for_epoch(arg_count as u64, epoch) + .unwrap_or(ExecutionCost::ZERO); + super::saturating_add_cost(&mut sc.min, &application_cost); + super::saturating_add_cost(&mut sc.max, &application_cost); + + // Use narrowed types for InnerTypeCheckCost. Since we know the exact argument types + // at this call site, we use size() for both min and max. In Epoch33+ the VM charges + // based on the actual argument's Value::size() which equals TypeSignature::type_of(arg).size(), + // matching our narrowed type's size(). + for narrowed_type in &narrowed_types { + let type_check_cost = ClarityCostFunction::InnerTypeCheckCost + .eval_for_epoch(u64::from(narrowed_type.size().unwrap_or(0)), epoch) + .unwrap_or(ExecutionCost::ZERO); + super::saturating_add_cost(&mut sc.min, &type_check_cost); + super::saturating_add_cost(&mut sc.max, &type_check_cost); + } + + Some(sc) +} + +/// Extract function parameter count and signature args from a function definition list. +/// The list is the inner contents of a `define-public`/`define-private`/`define-read-only`. 
+fn extract_function_signature_from_list( + list: &[SymbolicExpression], +) -> Option<(usize, &[SymbolicExpression])> { + let signature = list.get(1)?.match_list()?; + if signature.len() <= 1 { + return Some((0, &[])); + } + let args = &signature[1..]; + Some((args.len(), args)) +} + /// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) fn build_function_definition_cost_analysis_tree( list: &[SymbolicExpression], outer_user_args: &UserArgumentsContext, cost_map: &HashMap>, + function_defs: &HashMap, clarity_version: &ClarityVersion, epoch: StacksEpochId, + arg_type_overrides: Option<&[TypeSignature]>, + arg_known_values: Option<&[Option]>, env: &mut ExecutionState, invoke_ctx: &InvocationContext, ) -> Result<(String, CostAnalysisNode), StaticCostError> { @@ -946,12 +1242,13 @@ fn build_function_definition_cost_analysis_tree( arguments: HashMap::new(), map_types: outer_user_args.map_types.clone(), data_var_types: outer_user_args.data_var_types.clone(), + known_values: HashMap::new(), }; // Process function arguments: (a u64) // Use a free cost tracker since we're just parsing types let mut free_tracker = clarity::vm::costs::LimitedCostTracker::new_free(); - for arg_expr in signature.iter().skip(1) { + for (i, arg_expr) in signature.iter().skip(1).enumerate() { if let Some(arg_list) = arg_expr.match_list() { if arg_list.len() == 2 { let arg_name = arg_list[0] @@ -962,14 +1259,25 @@ fn build_function_definition_cost_analysis_tree( let arg_type_expr = &arg_list[1]; - // Parse the type from the AST to TypeSignature - let arg_type = + // Parse the declared type from the AST + let declared_type = TypeSignature::parse_type_repr(epoch, arg_type_expr, &mut free_tracker) .map_err(|e| StaticCostError::TypeParse(format!("{e:?}")))?; + // Use the override type if provided (for call-site narrowing) + let arg_type = arg_type_overrides + .and_then(|overrides| overrides.get(i)) + .cloned() + .unwrap_or(declared_type); + // Add to function's 
user arguments context function_user_args.add_argument(arg_name.clone(), arg_type.clone()); + // Seed known constant value for this argument, if provided. + if let Some(Some(value)) = arg_known_values.and_then(|kvs| kvs.get(i)) { + function_user_args.add_known_value(arg_name.clone(), value.clone()); + } + // Create UserArgument node children.push(CostAnalysisNode::leaf( CostExprNode::UserArgument(arg_name.clone(), arg_type), @@ -984,6 +1292,7 @@ fn build_function_definition_cost_analysis_tree( body, &function_user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1083,6 +1392,7 @@ fn build_listlike_cost_analysis_tree( exprs: &[SymbolicExpression], user_args: &UserArgumentsContext, cost_map: &HashMap>, + function_defs: &HashMap, clarity_version: &ClarityVersion, epoch: StacksEpochId, env: &mut ExecutionState, @@ -1104,6 +1414,7 @@ fn build_listlike_cost_analysis_tree( &exprs[0], user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1115,6 +1426,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1169,6 +1481,7 @@ fn build_listlike_cost_analysis_tree( &exprs[1], &extended_user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1184,6 +1497,7 @@ fn build_listlike_cost_analysis_tree( expr, &extended_user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1195,11 +1509,29 @@ fn build_listlike_cost_analysis_tree( } else if native_function == NativeFunctions::If { // `If` creates a nested context let nested_depth = let_depth + 1; - for expr in exprs[1..].iter() { + // Fold the if when its condition resolves to a constant bool: + // build only the condition + selected branch, and below replace + // the NativeFunction(If) node with NestedExpression so the cost + // sum is non-branching. 
+ let folded_branch = exprs + .get(1) + .and_then(|cond| try_resolve_constant_bool(cond, user_args)); + let selected_indices: Vec = if let Some(taken) = folded_branch { + let branch_idx = if taken { 2 } else { 3 }; + if exprs.get(branch_idx).is_some() { + vec![1, branch_idx] + } else { + (1..exprs.len()).collect() + } + } else { + (1..exprs.len()).collect() + }; + for &idx in &selected_indices { let (_, child_tree) = build_cost_analysis_tree( - expr, + &exprs[idx], user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1208,6 +1540,24 @@ fn build_listlike_cost_analysis_tree( )?; children.push(child_tree); } + if folded_branch.is_some() { + // Compute the If's own cost (lookup + special function), but + // emit a NestedExpression node so calculate_total_cost_with_branching + // treats it as non-branching. + let cost = calculate_function_cost_from_native_function( + native_function, + (exprs.len() - 1) as u64, + &exprs[1..], + epoch, + Some(user_args), + Some(invoke_ctx.contract_context), + )?; + return Ok(CostAnalysisNode::new( + CostExprNode::NestedExpression, + cost, + children, + )); + } } else { // For other functions, build all children with current depth for expr in exprs[1..].iter() { @@ -1215,6 +1565,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1257,7 +1608,8 @@ fn build_listlike_cost_analysis_tree( .and_then(|name| cost_map.get(name.as_str())) .and_then(|c| c.as_ref()) { - let list_max_len = get_map_filter_fold_list_max_len(&exprs[2..], user_args); + let list_max_len = + get_map_filter_fold_list_max_len(&exprs[2..], user_args, epoch); let mut multiplied_min = called_fn_cost.min.clone(); multiply_cost(&mut multiplied_min, list_max_len); let mut multiplied_max = called_fn_cost.max.clone(); @@ -1285,6 +1637,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1305,7 +1658,26 @@ fn 
build_listlike_cost_analysis_tree( }); if cost_map.contains_key(name.as_str()) { let expr_node = CostExprNode::UserFunction(name.clone()); - let cost = calculate_function_cost(name.as_str(), cost_map)?; + let default_cost = calculate_function_cost(name.as_str(), cost_map)?; + let mut cost = try_narrow_user_function_cost( + name.as_str(), + &exprs[1..], + user_args, + function_defs, + cost_map, + clarity_version, + epoch, + env, + invoke_ctx, + ) + .unwrap_or(default_cost); + // The VM's eval() charges LookupFunction(0) for every list + // expression, including user-defined function calls. + let fn_lookup_cost = ClarityCostFunction::LookupFunction + .eval_for_epoch(0, epoch) + .unwrap_or(ExecutionCost::ZERO); + super::saturating_add_cost(&mut cost.min, &fn_lookup_cost); + super::saturating_add_cost(&mut cost.max, &fn_lookup_cost); (expr_node, cost) } else if is_callable_arg { // Callable (trait) arguments used in call position — the actual @@ -1324,6 +1696,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1341,6 +1714,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1360,6 +1734,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1379,6 +1754,7 @@ fn build_listlike_cost_analysis_tree( expr, user_args, cost_map, + function_defs, clarity_version, epoch, env, @@ -1567,6 +1943,7 @@ mod tests { clarity_version: &ClarityVersion, ) -> Result { let cost_map: HashMap> = HashMap::new(); + let function_defs: HashMap = HashMap::new(); let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version let ast = make_ast(source, epoch, clarity_version)?; @@ -1583,6 +1960,7 @@ mod tests { expr, &user_args, &cost_map, + &function_defs, clarity_version, epoch, &mut env, @@ -1653,11 +2031,13 @@ mod tests { QualifiedContractIdentifier::transient(), 
ClarityVersion::Clarity3, ); + let function_defs = HashMap::new(); let (mut env, invoke_ctx) = owned_env.get_exec_environment(None, None, &contract_context); let (_, cost_tree) = build_cost_analysis_tree( expr, &user_args, &cost_map, + &function_defs, &ClarityVersion::Clarity3, epoch, &mut env, @@ -1717,11 +2097,13 @@ mod tests { QualifiedContractIdentifier::transient(), ClarityVersion::Clarity3, ); + let function_defs = HashMap::new(); let (mut env, invoke_ctx) = owned_env.get_exec_environment(None, None, &contract_context); let (_, cost_tree) = build_cost_analysis_tree( expr, &user_args, &cost_map, + &function_defs, &ClarityVersion::Clarity3, epoch, &mut env, @@ -1767,11 +2149,13 @@ mod tests { QualifiedContractIdentifier::transient(), ClarityVersion::Clarity3, ); + let function_defs = HashMap::new(); let (mut env, invoke_ctx) = owned_env.get_exec_environment(None, None, &contract_context); let (_, cost_tree) = build_cost_analysis_tree( expr, &user_args, &cost_map, + &function_defs, &ClarityVersion::Clarity3, epoch, &mut env, @@ -1807,11 +2191,13 @@ mod tests { QualifiedContractIdentifier::transient(), ClarityVersion::Clarity3, ); + let function_defs = HashMap::new(); let (mut env, invoke_ctx) = owned_env.get_exec_environment(None, None, &contract_context); let (_, cost_tree) = build_cost_analysis_tree( expr, &user_args, &cost_map, + &function_defs, &ClarityVersion::Clarity3, epoch, &mut env, @@ -1962,6 +2348,64 @@ mod tests { ); } + #[test] + fn test_infer_type_from_listcons_single_uint() { + let ast = build_test_ast("(list u1)"); + let exprs = ast.expressions[0].match_list().unwrap(); + let user_args = UserArgumentsContext::new(); + let result = infer_type_from_listcons(exprs, &user_args, StacksEpochId::latest()).unwrap(); + let expected = TypeSignature::SequenceType(SequenceSubtype::ListType( + ListTypeData::new_list(TypeSignature::UIntType, 1).unwrap(), + )); + assert_eq!(result, expected); + } + + #[test] + fn test_infer_type_from_listcons_multiple_uints() { 
+ let ast = build_test_ast("(list u1 u2 u3)"); + let exprs = ast.expressions[0].match_list().unwrap(); + let user_args = UserArgumentsContext::new(); + let result = infer_type_from_listcons(exprs, &user_args, StacksEpochId::latest()).unwrap(); + let expected = TypeSignature::SequenceType(SequenceSubtype::ListType( + ListTypeData::new_list(TypeSignature::UIntType, 3).unwrap(), + )); + assert_eq!(result, expected); + } + + #[test] + fn test_infer_type_from_listcons_empty_returns_none() { + // (list) with no elements should return None + let ast = build_test_ast("(list)"); + let exprs = ast.expressions[0].match_list().unwrap(); + let user_args = UserArgumentsContext::new(); + let result = infer_type_from_listcons(exprs, &user_args, StacksEpochId::latest()); + assert!(result.is_none()); + } + + #[test] + fn test_infer_type_from_listcons_not_list_returns_none() { + // A non-list expression like (+ u1 u2) should return None + let ast = build_test_ast("(+ u1 u2)"); + let exprs = ast.expressions[0].match_list().unwrap(); + let user_args = UserArgumentsContext::new(); + let result = infer_type_from_listcons(exprs, &user_args, StacksEpochId::latest()); + assert!(result.is_none()); + } + + #[test] + fn test_infer_type_from_listcons_with_user_arg() { + // When a list element references a known user argument, infer from that type + let mut user_args = UserArgumentsContext::new(); + user_args.add_argument("x".into(), TypeSignature::IntType); + let ast = build_test_ast("(list x)"); + let exprs = ast.expressions[0].match_list().unwrap(); + let result = infer_type_from_listcons(exprs, &user_args, StacksEpochId::latest()).unwrap(); + let expected = TypeSignature::SequenceType(SequenceSubtype::ListType( + ListTypeData::new_list(TypeSignature::IntType, 1).unwrap(), + )); + assert_eq!(result, expected); + } + #[test] fn test_keccak256_cost_varies_with_buffer_size() { let v = &ClarityVersion::Clarity3; diff --git a/components/clarity-static-cost/src/static_cost/special_costs.rs 
b/components/clarity-static-cost/src/static_cost/special_costs.rs index 6fa85f8a0..e88fd877e 100644 --- a/components/clarity-static-cost/src/static_cost/special_costs.rs +++ b/components/clarity-static-cost/src/static_cost/special_costs.rs @@ -10,6 +10,30 @@ use stacks_common::types::StacksEpochId; use super::cost_analysis::{StaticCost, UserArgumentsContext}; use super::cost_functions::{from_native_function, ClarityCostFunctionExt}; +/// Compute (min_serialized_size, max_serialized_size) for a type. +/// +/// `TypeSignature::min_size()` omits the 1-byte type prefix for some types +/// (e.g. UIntType returns 16 instead of 17) but includes it for others +/// (e.g. BoolType returns 1). `max_serialized_size()` consistently includes +/// the prefix. To tighten the lower bound, a min exactly one byte below max is +/// treated as fixed-size (min = max); otherwise min is clamped to at most max. +fn serialized_size_range(type_sig: &TypeSignature) -> (u64, u64) { + let max = type_sig + .max_serialized_size() + .ok() + .map(u64::from) + .unwrap_or(0); + let min_raw = u64::from(type_sig.min_size().unwrap_or(0)); + // If min_size < max_serialized_size by exactly the type prefix byte, + // the type is fixed-size and min should equal max. 
+ let min = if min_raw.saturating_add(1) == max { + max + } else { + min_raw.min(max) + }; + (min, max) +} + // Constants for tuple serialization overhead const TUPLE_LENGTH_ENCODING_BYTES: u64 = 4; const TUPLE_FIELD_OVERHEAD_BYTES: u64 = 2; @@ -667,15 +691,7 @@ pub fn cost_fetch_var( user_args: Option<&UserArgumentsContext>, ) -> StaticCost { let (min_size, max_size) = resolve_data_var_type(args, user_args) - .map(|type_sig| { - let min = u64::from(type_sig.min_size().unwrap_or(0)); - let max = type_sig - .max_serialized_size() - .ok() - .map(u64::from) - .unwrap_or(0); - (min, max) - }) + .map(serialized_size_range) .unwrap_or((0, 0)); let min_cost = ClarityCostFunction::FetchVar @@ -699,13 +715,19 @@ pub fn cost_set_var( user_args: Option<&UserArgumentsContext>, ) -> StaticCost { // SetVar args: [var-name, value] - // If the value expression is a literal, use its exact serialized size + // If the value expression is a literal — or an atom referring to a user + // argument with a known constant value — use its exact serialized size // since the runtime cost is based on the actual stored value. if let Some(exact_size) = args.get(1).and_then(|e| { - e.match_atom_value() + let value_opt = e + .match_atom_value() .or_else(|| e.match_literal_value()) - .and_then(|v| v.serialized_size().ok()) - .map(u64::from) + .cloned() + .or_else(|| { + let name = e.match_atom()?; + user_args?.get_known_value(name).cloned() + })?; + value_opt.serialized_size().ok().map(u64::from) }) { let cost = ClarityCostFunction::SetVar .eval_for_epoch(exact_size, epoch) @@ -716,18 +738,17 @@ pub fn cost_set_var( }; } + // If the value expression is an atom referring to a user argument with a + // narrower type than the data var's declared type, use that narrower type. + // This is set at the call site by `try_narrow_user_function_cost`. 
+ let value_type = args.get(1).and_then(|e| { + let name = e.match_atom()?; + user_args?.get_argument_type(name) + }); + let resolved_type = value_type.or_else(|| resolve_data_var_type(args, user_args)); + // Fall back to type-based range when the value isn't a literal - let (min_size, max_size) = resolve_data_var_type(args, user_args) - .map(|type_sig| { - let min = u64::from(type_sig.min_size().unwrap_or(0)); - let max = type_sig - .max_serialized_size() - .ok() - .map(u64::from) - .unwrap_or(0); - (min, max) - }) - .unwrap_or((0, 0)); + let (min_size, max_size) = resolved_type.map(serialized_size_range).unwrap_or((0, 0)); let min_cost = ClarityCostFunction::SetVar .eval_for_epoch(min_size, epoch) diff --git a/components/clarity-static-cost/tests/mod.rs b/components/clarity-static-cost/tests/mod.rs index a057fba08..d5531d38f 100644 --- a/components/clarity-static-cost/tests/mod.rs +++ b/components/clarity-static-cost/tests/mod.rs @@ -74,10 +74,12 @@ fn test_build_cost_analysis_tree_function_definition() { &contract_id, clarity_version, |env, invoke_ctx| { + let function_defs = std::collections::HashMap::new(); build_cost_analysis_tree( expr, &user_args, &cost_map, + &function_defs, &clarity_version, epoch, env, @@ -998,6 +1000,145 @@ fn test_against_dynamic_cost_analysis() { &list_32_uint_args, false, ), + // call a function that maps over a list, passing a shorter literal list + ( + indoc! {r#" + (define-data-var count uint u0) + + (define-private (add (n uint)) + (begin + (var-set count (+ (var-get count) n)) + ) + ) + + (define-public (add-many-64 (ns (list 64 uint))) + (begin + (map add ns) + (ok true) + ) + ) + + (define-public (add-u1) + (add-many-64 (list u1)) + ) + "#}, + "add-u1", + &[], + true, + ), + // Buffer narrowing: pass a 1-byte literal to a function declared `(buff 128)`. + // Static cost should match the dynamic cost of writing only 1 byte. + ( + indoc! 
{r#" + (define-data-var data-buff (buff 128) 0x) + + (define-private (write-buff (data (buff 128))) + (var-set data-buff data) + ) + + (define-public (do-write-buff) + (ok (write-buff 0x01)) + ) + "#}, + "do-write-buff", + &[], + true, + ), + // String narrowing: pass a short literal to a function declared `(string-ascii 100)`. + ( + indoc! {r#" + (define-data-var msg (string-ascii 100) "") + + (define-private (write-msg (s (string-ascii 100))) + (var-set msg s) + ) + + (define-public (do-write-msg) + (ok (write-msg "hi")) + ) + "#}, + "do-write-msg", + &[], + true, + ), + // Map over a literal list: cost should be 3x per-iteration, not 64x. + ( + indoc! {r#" + (define-data-var count uint u0) + + (define-private (add (n uint)) + (begin + (var-set count (+ (var-get count) n)) + ) + ) + + (define-public (map-literal) + (begin + (map add (list u1 u2 u3)) + (ok true) + ) + ) + "#}, + "map-literal", + &[], + true, + ), + // `if true` folding: literal true selects the write branch. + ( + indoc! {r#" + (define-data-var data-buff (buff 128) 0x) + + (define-private (write-if-true (write bool) (data (buff 128))) + (if write (var-set data-buff data) false) + ) + + (define-public (do-write) + (ok (write-if-true true 0x01)) + ) + "#}, + "do-write", + &[], + true, + ), + // `if false` folding: literal false selects the no-op branch. + ( + indoc! {r#" + (define-data-var data-buff (buff 128) 0x) + + (define-private (write-if-true (write bool) (data (buff 128))) + (if write (var-set data-buff data) false) + ) + + (define-public (dont-write) + (ok (write-if-true false 0x00)) + ) + "#}, + "dont-write", + &[], + true, + ), + // Indirect constant: a parent function whose own arg is a constant still + // folds when inlined into the child via known-value propagation. + ( + indoc! 
{r#" + (define-data-var data-buff (buff 128) 0x) + + (define-private (write-if-true (write bool) (data (buff 128))) + (if write (var-set data-buff data) false) + ) + + (define-private (always-true) + (write-if-true true 0x01) + ) + + (define-public (do-write-indirect) + (ok (always-true)) + ) + "#}, + "do-write-indirect", + &[], + true, + ), ]; let mut failures = Vec::new();