Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/AddAtomicMutex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ class AddAtomicMutex : public IRMutator {
std::string name = unique_name('t');
index_let = index;
index = Variable::make(index.type(), name);
body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body);
body = ReplaceStoreIndexWithVar(op->producer_name, index)(body);
}
// This generates a pointer to the mutex array
Expr mutex_array = Variable::make(
Expand Down Expand Up @@ -454,8 +454,8 @@ Stmt add_atomic_mutex(Stmt s, const std::vector<Function> &outputs) {
CheckAtomicValidity check;
s.accept(&check);
if (check.any_atomic) {
s = RemoveUnnecessaryMutexUse().mutate(s);
s = AddAtomicMutex(outputs).mutate(s);
s = RemoveUnnecessaryMutexUse()(s);
s = AddAtomicMutex(outputs)(s);
}
return s;
}
Expand Down
13 changes: 8 additions & 5 deletions src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class TrimStmtToPartsThatAccessBuffers : public IRMutator {
bool touches_buffer = false;
const map<string, FindBuffers::Result> &buffers;

protected:
using IRMutator::visit;

Expr visit(const Call *op) override {
Expand Down Expand Up @@ -185,10 +186,10 @@ Stmt add_image_checks_inner(Stmt s,

// Add the input buffer(s) and annotate which output buffers are
// used on host.
s.accept(&finder);
finder(s);

Scope<Interval> empty_scope;
Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs).mutate(s);
Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs)(s);
map<string, Box> boxes = boxes_touched(sub_stmt, empty_scope, fb);

// Now iterate through all the buffers, creating a list of lets
Expand Down Expand Up @@ -225,7 +226,7 @@ Stmt add_image_checks_inner(Stmt s,
string extent_name = concat_strings(name, ".extent.", i);
string stride_name = concat_strings(name, ".stride.", i);
replace_with_required[min_name] = Variable::make(Int(32), min_name + ".required");
replace_with_required[extent_name] = simplify(Variable::make(Int(32), extent_name + ".required"));
replace_with_required[extent_name] = Variable::make(Int(32), extent_name + ".required");
replace_with_required[stride_name] = Variable::make(Int(32), stride_name + ".required");
}
}
Expand Down Expand Up @@ -737,6 +738,7 @@ Stmt add_image_checks(const Stmt &s,
// Checks for images go at the marker deposited by computation
// bounds inference.
class Injector : public IRMutator {
protected:
using IRMutator::visit;

Expr visit(const Variable *op) override {
Expand Down Expand Up @@ -794,9 +796,10 @@ Stmt add_image_checks(const Stmt &s,
bool will_inject_host_copies)
: outputs(outputs), t(t), order(order), env(env), fb(fb), will_inject_host_copies(will_inject_host_copies) {
}
} injector(outputs, t, order, env, fb, will_inject_host_copies);
};
Injector injector(outputs, t, order, env, fb, will_inject_host_copies);

return injector.mutate(s);
return injector(s);
}

} // namespace Internal
Expand Down
2 changes: 1 addition & 1 deletion src/AlignLoads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class AlignLoads : public IRMutator {
} // namespace

Stmt align_loads(const Stmt &s, int alignment, int min_bytes_to_align) {
return AlignLoads(alignment, min_bytes_to_align).mutate(s);
return AlignLoads(alignment, min_bytes_to_align)(s);
}

} // namespace Internal
Expand Down
4 changes: 2 additions & 2 deletions src/AllocationBoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class StripDeclareBoxTouched : public IRMutator {
Stmt allocation_bounds_inference(Stmt s,
const map<string, Function> &env,
const FuncValueBounds &fb) {
s = AllocationInference(env, fb).mutate(s);
s = StripDeclareBoxTouched().mutate(s);
s = AllocationInference(env, fb)(s);
s = StripDeclareBoxTouched()(s);
return s;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ AssociativeOp prove_associativity(const string &f, vector<Expr> args, vector<Exp

// Replace any self-reference to Func 'f' with a Var
ConvertSelfRef csr(f, args, idx, op_x_names);
exprs[idx] = csr.mutate(exprs[idx]);
exprs[idx] = csr(exprs[idx]);
if (!csr.is_solvable) {
return AssociativeOp();
}
Expand Down
31 changes: 21 additions & 10 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class NoOpCollapsingMutator : public IRMutator {
};

class GenerateProducerBody : public NoOpCollapsingMutator {
protected:
const string &func;
vector<Expr> sema;
std::set<string> producers_dropped;
Expand Down Expand Up @@ -285,6 +286,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
};

class GenerateConsumerBody : public NoOpCollapsingMutator {
protected:
const string &func;
vector<Expr> sema;

Expand Down Expand Up @@ -342,6 +344,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator {
};

class CloneAcquire : public IRMutator {
protected:
using IRMutator::visit;

const string &old_name;
Expand Down Expand Up @@ -390,6 +393,7 @@ class CountConsumeNodes : public IRVisitor {
};

class ForkAsyncProducers : public IRMutator {
protected:
using IRMutator::visit;

const map<string, Function> &env;
Expand All @@ -414,8 +418,8 @@ class ForkAsyncProducers : public IRMutator {
sema_vars.push_back(Variable::make(type_of<halide_semaphore_t *>(), sema_names.back()));
}

Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(name, sema_vars).mutate(body);
Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires)(body);
Stmt consumer = GenerateConsumerBody(name, sema_vars)(body);

// Recurse on both sides
producer = mutate(producer);
Expand All @@ -434,7 +438,7 @@ class ForkAsyncProducers : public IRMutator {
// of the producer and consumer.
const vector<string> &clones = cloned_acquires[sema_name];
for (const auto &i : clones) {
body = CloneAcquire(sema_name, i).mutate(body);
body = CloneAcquire(sema_name, i)(body);
body = LetStmt::make(i, sema_space, body);
}

Expand Down Expand Up @@ -493,6 +497,7 @@ class ForkAsyncProducers : public IRMutator {
// simple failure case, error_async_require_fail. One has not been
// written for the complex nested case yet.)
class InitializeSemaphores : public IRMutator {
protected:
using IRMutator::visit;

const Type sema_type = type_of<halide_semaphore_t *>();
Expand Down Expand Up @@ -558,6 +563,7 @@ class InitializeSemaphores : public IRMutator {
// A class to support stmt_uses_vars queries that repeatedly hit the same
// sub-stmts. Used to support TightenProducerConsumerNodes below.
class CachingStmtUsesVars : public IRMutator {
protected:
const Scope<> &query;
bool found_use = false;
std::map<Stmt, bool> cache;
Expand Down Expand Up @@ -613,6 +619,7 @@ class CachingStmtUsesVars : public IRMutator {

// Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
class TightenProducerConsumerNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) {
Expand Down Expand Up @@ -703,6 +710,7 @@ class TightenProducerConsumerNodes : public IRMutator {

// Update indices to add ring buffer.
class UpdateIndices : public IRMutator {
protected:
using IRMutator::visit;

Stmt visit(const Provide *op) override {
Expand Down Expand Up @@ -734,6 +742,7 @@ class UpdateIndices : public IRMutator {

// Inject ring buffering.
class InjectRingBuffering : public IRMutator {
protected:
using IRMutator::visit;

struct Loop {
Expand Down Expand Up @@ -768,7 +777,7 @@ class InjectRingBuffering : public IRMutator {
}
current_index = current_index % f.schedule().ring_buffer();
// Adds an extra index to all of the references of f.
body = UpdateIndices(op->name, current_index).mutate(body);
body = UpdateIndices(op->name, current_index)(body);

if (f.schedule().async()) {
Expr sema_var = Variable::make(type_of<halide_semaphore_t *>(), f.name() + ".folding_semaphore.ring_buffer");
Expand Down Expand Up @@ -816,6 +825,7 @@ class InjectRingBuffering : public IRMutator {
// Broaden the scope of acquire nodes to pack trailing work into the
// same task and to potentially reduce the nesting depth of tasks.
class ExpandAcquireNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt visit(const Block *op) override {
Expand Down Expand Up @@ -918,6 +928,7 @@ class ExpandAcquireNodes : public IRMutator {
};

class TightenForkNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt make_fork(const Stmt &first, const Stmt &rest) {
Expand Down Expand Up @@ -1005,12 +1016,12 @@ class TightenForkNodes : public IRMutator {
} // namespace

Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
s = TightenProducerConsumerNodes(env).mutate(s);
s = InjectRingBuffering(env).mutate(s);
s = ForkAsyncProducers(env).mutate(s);
s = ExpandAcquireNodes().mutate(s);
s = TightenForkNodes().mutate(s);
s = InitializeSemaphores().mutate(s);
s = TightenProducerConsumerNodes(env)(s);
s = InjectRingBuffering(env)(s);
s = ForkAsyncProducers(env)(s);
s = ExpandAcquireNodes()(s);
s = TightenForkNodes()(s);
s = InitializeSemaphores()(s);
return s;
}

Expand Down
4 changes: 2 additions & 2 deletions src/AutoScheduleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ Expr substitute_var_estimates(Expr e) {
if (!e.defined()) {
return e;
}
return simplify(SubstituteVarEstimates().mutate(e));
return simplify(SubstituteVarEstimates()(e));
}

Stmt substitute_var_estimates(Stmt s) {
if (!s.defined()) {
return s;
}
return simplify(SubstituteVarEstimates().mutate(s));
return simplify(SubstituteVarEstimates()(s));
}

int string_to_int(const string &s) {
Expand Down
3 changes: 2 additions & 1 deletion src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace Internal {

namespace {
class BoundLoops : public IRMutator {
protected:
using IRMutator::visit;

std::vector<std::pair<std::string, Expr>> lets;
Expand Down Expand Up @@ -128,7 +129,7 @@ class BoundLoops : public IRMutator {
} // namespace

Stmt bound_constant_extent_loops(const Stmt &s) {
return BoundLoops().mutate(s);
return BoundLoops()(s);
}

} // namespace Internal
Expand Down
2 changes: 1 addition & 1 deletion src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class BoundSmallAllocations : public IRMutator {
} // namespace

Stmt bound_small_allocations(const Stmt &s) {
return BoundSmallAllocations().mutate(s);
return BoundSmallAllocations()(s);
}

} // namespace Internal
Expand Down
17 changes: 9 additions & 8 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class Bounds : public IRVisitor {

#endif // DO_TRACK_BOUNDS_INTERVALS

private:
protected:
// Compute the intrinsic bounds of a function.
void bounds_of_func(const string &name, int value_index, Type t) {
// if we can't get a good bound from the function, fall back to the bounds of the type.
Expand Down Expand Up @@ -1799,7 +1799,7 @@ Interval bounds_of_expr_in_scope_with_indent(const Expr &expr, const Scope<Inter
#if DO_TRACK_BOUNDS_INTERVALS
b.log_indent = indent + 1;
#endif
expr.accept(&b);
b(expr);
#if DO_TRACK_BOUNDS_INTERVALS
debug(0) << spaces << " mn=" << simplify(b.interval.min) << "\n"
<< spaces << " mx=" << simplify(b.interval.max) << "\n"
Expand Down Expand Up @@ -2023,6 +2023,7 @@ class FindInnermostVar : public IRVisitor {

// Place innermost vars in an IfThenElse's condition as far to the left as possible.
class SolveIfThenElse : public IRMutator {
protected:
// Scope of variable names and their depths. Higher depth indicates
// variable defined more innermost.
Scope<int> vars_depth;
Expand Down Expand Up @@ -2255,7 +2256,7 @@ class BoxesTouched : public IRGraphVisitor {

#endif // DO_TRACK_BOUNDS_INTERVALS

private:
protected:
struct VarInstance {
string var;
int instance;
Expand Down Expand Up @@ -3107,7 +3108,7 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool
// as possible, so that BoxesTouched can prune the variable scope tighter
// when encountering the IfThenElse.
if (s.defined()) {
s = SolveIfThenElse().mutate(s);
s = SolveIfThenElse()(s);
}

// Do calls and provides separately, for better simplification.
Expand All @@ -3116,18 +3117,18 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool

if (consider_calls) {
if (e.defined()) {
e.accept(&calls);
calls(e);
}
if (s.defined()) {
s.accept(&calls);
calls(s);
}
}
if (consider_provides) {
if (e.defined()) {
e.accept(&provides);
provides(e);
}
if (s.defined()) {
s.accept(&provides);
provides(s);
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ class BoundsInference : public IRMutator {
} select_to_if_then_else;

for (auto &e : exprs) {
e.value = select_to_if_then_else.mutate(e.value);
e.value = select_to_if_then_else(e.value);
}
}

Expand Down Expand Up @@ -1382,8 +1382,7 @@ Stmt bounds_inference(Stmt s,
s = For::make("<outermost>", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
.mutate(s);
outputs, func_bounds, target)(s);
return s.as<For>()->body;
}

Expand Down
5 changes: 5 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ target_compile_definitions(Halide PRIVATE WITH_SPIRV)
target_compile_definitions(Halide PRIVATE WITH_VULKAN)
target_compile_definitions(Halide PRIVATE WITH_WEBGPU)

if (WITH_COMPILER_PROFILING)
target_compile_definitions(Halide PRIVATE WITH_COMPILER_PROFILING)
endif()


##
# Flatbuffers and Serialization dependencies.
##
Expand Down
Loading
Loading