Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions compiler/rustc_next_trait_solver/src/solve/search_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ where
response_no_constraints(cx, input, Certainty::overflow(true))
}

const FIXPOINT_OVERFLOW_AMBIGUITY_KIND: Certainty = Certainty::overflow(false);
fn fixpoint_overflow_result(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
response_no_constraints(cx, input, Certainty::overflow(false))
response_no_constraints(cx, input, Self::FIXPOINT_OVERFLOW_AMBIGUITY_KIND)
}

fn is_ambiguous_result(result: QueryResult<I>) -> Option<Certainty> {
Expand All @@ -111,14 +112,6 @@ where
})
}

fn propagate_ambiguity(
cx: I,
for_input: CanonicalInput<I>,
certainty: Certainty,
) -> QueryResult<I> {
response_no_constraints(cx, for_input, certainty)
}

fn compute_goal(
search_graph: &mut SearchGraph<D>,
cx: I,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_type_ir/src/interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ impl<T, R, E> CollectAndApply<T, R> for Result<T, E> {
impl<I: Interner> search_graph::Cx for I {
type Input = CanonicalInput<I>;
type Result = QueryResult<I>;
type AmbiguityInfo = Certainty;
type AmbiguityKind = Certainty;

type DepNodeIndex = I::DepNodeIndex;
type Tracked<T: Debug + Clone> = I::Tracked<T>;
Expand Down
145 changes: 47 additions & 98 deletions compiler/rustc_type_ir/src/search_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub use global_cache::GlobalCache;
pub trait Cx: Copy {
type Input: Debug + Eq + Hash + Copy;
type Result: Debug + Eq + Hash + Copy;
type AmbiguityInfo: Debug + Eq + Hash + Copy;
type AmbiguityKind: Debug + Eq + Hash + Copy;

type DepNodeIndex;
type Tracked<T: Debug + Clone>: Debug;
Expand Down Expand Up @@ -92,19 +92,16 @@ pub trait Delegate: Sized {
cx: Self::Cx,
input: <Self::Cx as Cx>::Input,
) -> <Self::Cx as Cx>::Result;

const FIXPOINT_OVERFLOW_AMBIGUITY_KIND: <Self::Cx as Cx>::AmbiguityKind;
fn fixpoint_overflow_result(
cx: Self::Cx,
input: <Self::Cx as Cx>::Input,
) -> <Self::Cx as Cx>::Result;

fn is_ambiguous_result(
result: <Self::Cx as Cx>::Result,
) -> Option<<Self::Cx as Cx>::AmbiguityInfo>;
fn propagate_ambiguity(
cx: Self::Cx,
for_input: <Self::Cx as Cx>::Input,
ambiguity_info: <Self::Cx as Cx>::AmbiguityInfo,
) -> <Self::Cx as Cx>::Result;
) -> Option<<Self::Cx as Cx>::AmbiguityKind>;

fn compute_goal(
search_graph: &mut SearchGraph<Self>,
Expand Down Expand Up @@ -337,10 +334,6 @@ impl CycleHeads {
self.heads.last_key_value().map(|(k, _)| *k)
}

fn opt_lowest_cycle_head_index(&self) -> Option<StackDepth> {
self.heads.first_key_value().map(|(k, _)| *k)
}

fn remove_highest_cycle_head(&mut self) -> CycleHead {
let last = self.heads.pop_last();
last.unwrap().1
Expand Down Expand Up @@ -476,10 +469,6 @@ impl PathsToNested {
/// in this case as it could otherwise result in behavioral differences.
/// Cycles can impact behavior. The cycle ABA may have different final
/// results from the cycle BAB depending on the cycle root.
///
/// We only start tracking nested goals once we've either encountered
/// overflow or a solver cycle. This is a performance optimization to
/// avoid tracking nested goals on the happy path.
#[derive_where(Debug, Default, Clone; X: Cx)]
struct NestedGoals<X: Cx> {
nested_goals: HashMap<X::Input, PathsToNested>,
Expand Down Expand Up @@ -529,9 +518,7 @@ impl<X: Cx> NestedGoals<X> {
/// goals still on the stack.
#[derive_where(Debug; X: Cx)]
struct ProvisionalCacheEntry<X: Cx> {
/// Whether evaluating the goal encountered overflow. This is used to
/// disable the cache entry except if the last goal on the stack is
/// already involved in this cycle.
/// Whether evaluating the goal encountered overflow.
encountered_overflow: bool,
/// All cycle heads this cache entry depends on.
heads: CycleHeads,
Expand Down Expand Up @@ -601,8 +588,7 @@ pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
/// cache entry.
enum UpdateParentGoalCtxt<'a, X: Cx> {
Ordinary(&'a NestedGoals<X>),
CycleOnStack(X::Input),
ProvisionalCacheHit,
ProvisionalCacheHitOrDirectCycle,
}

impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
Expand All @@ -621,6 +607,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
/// have the same impact on the remaining evaluation.
fn update_parent_goal(
stack: &mut Stack<X>,
child_input: X::Input,
step_kind_from_parent: PathKind,
required_depth_for_nested: usize,
heads: impl Iterator<Item = (StackDepth, CycleHead)>,
Expand Down Expand Up @@ -652,27 +639,14 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
Ordering::Greater => unreachable!(),
}
}
let parent_depends_on_cycle = match context {

parent.nested_goals.insert(child_input, step_kind_from_parent.into());
match context {
UpdateParentGoalCtxt::Ordinary(nested_goals) => {
parent.nested_goals.extend_from_child(step_kind_from_parent, nested_goals);
!nested_goals.is_empty()
}
UpdateParentGoalCtxt::CycleOnStack(head) => {
// We lookup provisional cache entries before detecting cycles.
// We therefore can't use a global cache entry if it contains a cycle
// whose head is in the provisional cache.
parent.nested_goals.insert(head, step_kind_from_parent.into());
true
}
UpdateParentGoalCtxt::ProvisionalCacheHit => true,
UpdateParentGoalCtxt::ProvisionalCacheHitOrDirectCycle => {}
};
// Once we've got goals which encountered overflow or a cycle,
// we track all goals whose behavior may depend depend on these
// goals as this change may cause them to now depend on additional
// goals, resulting in new cycles. See the dev-guide for examples.
if parent_depends_on_cycle {
parent.nested_goals.insert(parent.input, PathsToNested::EMPTY);
}
}
}

Expand Down Expand Up @@ -784,6 +758,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
return result;
}

// Detect cycles on the stack. We do this before the global cache lookup as we'd
// need to check whether the current goal is on the stack regardless when checking
// whether a global cache entry is applicable.
if let Some(result) = self.check_cycle_on_stack(cx, input, step_kind_from_parent) {
return result;
}

// Lookup the global cache unless we're building proof trees or are currently
// fuzzing.
let validate_cache = if !D::inspect_is_noop(inspect) {
Expand All @@ -804,15 +785,6 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
None
};

// Detect cycles on the stack. We do this after the global cache lookup to
// avoid iterating over the stack in case a goal has already been computed.
// This may not have an actual performance impact and we could reorder them
// as it may reduce the number of `nested_goals` we need to track.
if let Some(result) = self.check_cycle_on_stack(cx, input, step_kind_from_parent) {
debug_assert!(validate_cache.is_none(), "global cache and cycle on stack: {input:?}");
return result;
}

// Unfortunate, it looks like we actually have to compute this goal.
self.stack.push(StackEntry {
input,
Expand Down Expand Up @@ -840,6 +812,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
// lazily update its parent goal.
Self::update_parent_goal(
&mut self.stack,
input,
step_kind_from_parent,
evaluation_result.required_depth,
evaluation_result.heads.iter(),
Expand Down Expand Up @@ -929,8 +902,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
#[derive_where(Debug; X: Cx)]
enum RebaseReason<X: Cx> {
NoCycleUsages,
Ambiguity(X::AmbiguityInfo),
Overflow,
Ambiguity(X::AmbiguityKind),
/// We've actually reached a fixpoint.
///
/// This either happens in the first evaluation step for the cycle head.
Expand Down Expand Up @@ -961,10 +933,9 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
/// cache entries to also be ambiguous. This causes some undesirable ambiguity for nested
/// goals whose result doesn't actually depend on this cycle head, but that's acceptable
/// to me.
#[instrument(level = "trace", skip(self, cx))]
#[instrument(level = "trace", skip(self))]
fn rebase_provisional_cache_entries(
&mut self,
cx: X,
stack_entry: &StackEntry<X>,
rebase_reason: RebaseReason<X>,
) {
Expand Down Expand Up @@ -1039,18 +1010,22 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
}

// The provisional cache entry does depend on the provisional result
// of the popped cycle head. We need to mutate the result of our
// provisional cache entry in case we did not reach a fixpoint.
// of the popped cycle head. In case we didn't actually reach a fixpoint,
// we must not keep potentially incorrect provisional cache entries around.
match rebase_reason {
// If the cycle head does not actually depend on itself, then
// the provisional result used by the provisional cache entry
// is not actually equal to the final provisional result. We
// need to discard the provisional cache entry in this case.
RebaseReason::NoCycleUsages => return false,
RebaseReason::Ambiguity(info) => {
*result = D::propagate_ambiguity(cx, input, info);
// If we avoid rerunning a goal due to ambiguity, we only keep provisional
// results which depend on that cycle head if these are already ambiguous
// themselves.
RebaseReason::Ambiguity(kind) => {
if !D::is_ambiguous_result(*result).is_some_and(|k| k == kind) {
return false;
}
}
RebaseReason::Overflow => *result = D::fixpoint_overflow_result(cx, input),
RebaseReason::ReachedFixpoint(None) => {}
RebaseReason::ReachedFixpoint(Some(path_kind)) => {
if !popped_head.usages.is_single(path_kind) {
Expand Down Expand Up @@ -1093,36 +1068,20 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
for &ProvisionalCacheEntry { encountered_overflow, ref heads, path_from_head, result } in
entries
{
let head_index = heads.highest_cycle_head_index();
if encountered_overflow {
// This check is overly strict and very subtle. We need to make sure that if
// a global cache entry depends on some goal without adding it to its
// `nested_goals`, that goal must never have an applicable provisional
// cache entry to avoid incorrectly applying the cache entry.
//
// As we'd have to otherwise track literally all nested goals, we only
// apply provisional cache entries which encountered overflow once the
// current goal is already part of the same cycle. This check could be
// improved but seems to be good enough for now.
let last = self.stack.last().unwrap();
if last.heads.opt_lowest_cycle_head_index().is_none_or(|lowest| lowest > head_index)
{
continue;
}
}
Copy link
Copy Markdown
Contributor Author

@lcnr lcnr Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fast path meant exponential blowup in ml-kem ended up causing a hang. let's instead just track all nested goals

View changes since the review


// A provisional cache entry is only valid if the current path from its
// highest cycle head to the goal is the same.
let head_index = heads.highest_cycle_head_index();
if path_from_head
== Self::cycle_path_kind(&self.stack, step_kind_from_parent, head_index)
{
Self::update_parent_goal(
&mut self.stack,
input,
step_kind_from_parent,
0,
heads.iter(),
encountered_overflow,
UpdateParentGoalCtxt::ProvisionalCacheHit,
UpdateParentGoalCtxt::ProvisionalCacheHitOrDirectCycle,
);
debug!(?head_index, ?path_from_head, "provisional cache hit");
return Some(result);
Expand Down Expand Up @@ -1167,18 +1126,12 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
// A provisional cache entry is applicable if the path to
// its highest cycle head is equal to the expected path.
for &ProvisionalCacheEntry {
encountered_overflow,
encountered_overflow: _,
ref heads,
path_from_head: head_to_provisional,
result: _,
} in entries.iter()
{
// We don't have to worry about provisional cache entries which encountered
// overflow, see the relevant comment in `lookup_provisional_cache`.
if encountered_overflow {
continue;
}

// A provisional cache entry only applies if the path from its highest head
// matches the path when encountering the goal.
//
Expand Down Expand Up @@ -1241,6 +1194,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
let heads = iter::empty();
Self::update_parent_goal(
&mut self.stack,
input,
step_kind_from_parent,
required_depth,
heads,
Expand Down Expand Up @@ -1273,11 +1227,12 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
let head = CycleHead { paths_to_head: step_kind_from_parent.into(), usages };
Self::update_parent_goal(
&mut self.stack,
input,
step_kind_from_parent,
0,
iter::once((head_index, head)),
false,
UpdateParentGoalCtxt::CycleOnStack(input),
UpdateParentGoalCtxt::ProvisionalCacheHitOrDirectCycle,
);

// Return the provisional result or, if we're in the first iteration,
Expand Down Expand Up @@ -1352,17 +1307,12 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
// final result is equal to the initial response for that case.
if let Ok(fixpoint) = self.reached_fixpoint(&stack_entry, usages, result) {
self.rebase_provisional_cache_entries(
cx,
&stack_entry,
RebaseReason::ReachedFixpoint(fixpoint),
);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
} else if usages.is_empty() {
self.rebase_provisional_cache_entries(
cx,
&stack_entry,
RebaseReason::NoCycleUsages,
);
self.rebase_provisional_cache_entries(&stack_entry, RebaseReason::NoCycleUsages);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
}

Expand All @@ -1371,19 +1321,15 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
// response in the next iteration in this case. These changes would
// likely either be caused by incompleteness or can change the maybe
// cause from ambiguity to overflow. Returning ambiguity always
// preserves soundness and completeness even if the goal is be known
// to succeed or fail.
// preserves soundness and completeness even if the goal could
// otherwise succeed or fail.
//
// This prevents exponential blowup affecting multiple major crates.
// As we only get to this branch if we haven't yet reached a fixpoint,
// we also taint all provisional cache entries which depend on the
// current goal.
if let Some(info) = D::is_ambiguous_result(result) {
self.rebase_provisional_cache_entries(
cx,
&stack_entry,
RebaseReason::Ambiguity(info),
);
if let Some(kind) = D::is_ambiguous_result(result) {
self.rebase_provisional_cache_entries(&stack_entry, RebaseReason::Ambiguity(kind));
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
};

Expand All @@ -1393,7 +1339,10 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
if i >= D::FIXPOINT_STEP_LIMIT {
debug!("canonical cycle overflow");
let result = D::fixpoint_overflow_result(cx, input);
self.rebase_provisional_cache_entries(cx, &stack_entry, RebaseReason::Overflow);
self.rebase_provisional_cache_entries(
&stack_entry,
RebaseReason::Ambiguity(D::FIXPOINT_OVERFLOW_AMBIGUITY_KIND),
);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
}

Expand Down Expand Up @@ -1445,7 +1394,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
evaluation_result: EvaluationResult<X>,
dep_node: X::DepNodeIndex,
) {
debug!(?evaluation_result, "insert global cache");
debug!(?input, ?evaluation_result, "insert global cache");
cx.with_global_cache(|cache| cache.insert(cx, input, evaluation_result, dep_node))
}
}
Loading
Loading