-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Hoist the shared dense load as a let in stage_strided_loads #8964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| #include "StageStridedLoads.h" | ||
| #include "CSE.h" | ||
| #include "ExprUsesVar.h" | ||
| #include "IREquality.h" | ||
| #include "IRMutator.h" | ||
| #include "IROperator.h" | ||
|
|
@@ -95,12 +96,15 @@ class FindStridedLoads : public IRVisitor { | |
| base = base_add->a; | ||
| offset = *off; | ||
| } | ||
| } else if (auto off = as_const_int(base)) { | ||
| base = 0; | ||
| offset = *off; | ||
| } | ||
|
|
||
| // TODO: We do not yet handle nested vectorization here for | ||
| // ramps which have not already collapsed. We could potentially | ||
| // handle more interesting types of shuffle than simple flat slices. | ||
| if (stride >= 2 && stride < r->lanes && r->stride.type().is_scalar()) { | ||
| if (stride >= 2 && stride <= r->lanes && r->stride.type().is_scalar()) { | ||
| const IRNode *s = scope; | ||
| const Allocate *a = nullptr; | ||
| if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { | ||
|
|
@@ -154,17 +158,35 @@ class FindStridedLoads : public IRVisitor { | |
| // Replace a bunch of load expressions in a stmt | ||
| class ReplaceStridedLoads : public IRMutator { | ||
| public: | ||
| std::map<std::pair<const Allocate *, const Load *>, Expr> replacements; | ||
| std::map<const Load *, Expr> replacements; | ||
| std::map<const Allocate *, int> padding; | ||
| Scope<const Allocate *> allocation_scope; | ||
| std::map<const IRNode *, std::vector<std::pair<std::string, Expr>>> let_injections; | ||
|
|
||
| Stmt mutate(const Stmt &s) override { | ||
| Stmt stmt = IRMutator::mutate(s); | ||
| auto it = let_injections.find(s.get()); | ||
| if (it != let_injections.end()) { | ||
| for (const auto &[name, value] : it->second) { | ||
| stmt = LetStmt::make(name, value, stmt); | ||
| } | ||
| } | ||
| return stmt; | ||
| } | ||
|
|
||
| Expr mutate(const Expr &e) override { | ||
| Expr expr = IRMutator::mutate(e); | ||
| auto it = let_injections.find(e.get()); | ||
| if (it != let_injections.end()) { | ||
| for (const auto &[name, value] : it->second) { | ||
| expr = Let::make(name, value, expr); | ||
| } | ||
| } | ||
| return expr; | ||
| } | ||
|
|
||
| protected: | ||
| Expr visit(const Load *op) override { | ||
| const Allocate *alloc = nullptr; | ||
| if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { | ||
| alloc = *a_ptr; | ||
| } | ||
| auto it = replacements.find({alloc, op}); | ||
| auto it = replacements.find(op); | ||
| if (it != replacements.end()) { | ||
| return mutate(it->second); | ||
| } else { | ||
|
|
@@ -173,7 +195,6 @@ class ReplaceStridedLoads : public IRMutator { | |
| } | ||
|
|
||
| Stmt visit(const Allocate *op) override { | ||
| ScopedBinding bind(allocation_scope, op->name, op); | ||
| auto it = padding.find(op); | ||
| Stmt s = IRMutator::visit(op); | ||
| if (it == padding.end()) { | ||
|
|
@@ -191,12 +212,88 @@ class ReplaceStridedLoads : public IRMutator { | |
| using IRMutator::visit; | ||
| }; | ||
|
|
||
| const IRNode *innermost_containing_node(const IRNode *root, const std::set<const Load *> &exprs) { | ||
| const IRNode *result = nullptr; | ||
| // Finds the innermost node (Stmt or Expr) under root that contains the | ||
| // largest number of the loads in exprs, with ties breaking inwards. 'seen' is a running count of matching loads encountered so far; a node only becomes the result if no matches had been seen before entering it (old == 0) and the count after visiting it exceeds 'best'. | ||
| int seen = 0, best = 0; | ||
| mutate_with(root, // Used only to traverse: both handlers return their argument unchanged. | ||
| [&](auto *self, const Stmt &s) { | ||
| int old = seen; | ||
| self->mutate_base(s); | ||
| if (old == 0 && seen > best) { | ||
| result = s.get(); | ||
| best = seen; | ||
| } | ||
| return s; // | ||
| }, | ||
| [&](auto *self, const Expr &e) { | ||
| int old = seen; | ||
| const Load *l = e.as<Load>(); | ||
| if (l && exprs.count(l)) { | ||
| seen++; | ||
| }; | ||
| self->mutate_base(e); | ||
| if (old == 0 && seen > best) { | ||
| result = e.get(); | ||
| best = seen; | ||
| } | ||
| return e; // | ||
| }); | ||
| internal_assert(seen) << "None of the exprs were found\n"; | ||
| return result; | ||
| } | ||
|
|
||
| bool can_hoist_shared_load(const IRNode *n, const std::string &buf, const Expr &idx) { | ||
| // Returns true only if nothing inside n defines a variable that idx uses | ||
| // (via a Let, LetStmt, or For), allocates a buffer named buf, or stores | ||
| // to a buffer named buf. NOTE(review): Call nodes are not inspected here, so extern calls with side effects are not accounted for - confirm whether hoisting past them is safe. | ||

||
| bool result = true; | ||
| visit_with(n, // | ||
| [&](auto *self, const Let *let) { // | ||
| result &= !expr_uses_var(idx, let->name); | ||
| self->visit_base(let); | ||
| }, | ||
| [&](auto *self, const LetStmt *let) { // | ||
| result &= !expr_uses_var(idx, let->name); | ||
| self->visit_base(let); | ||
| }, | ||
| [&](auto *self, const For *loop) { // | ||
| result &= !expr_uses_var(idx, loop->name); | ||
| self->visit_base(loop); | ||
| }, | ||
| [&](auto *self, const Allocate *alloc) { // | ||
| result &= alloc->name != buf; | ||
| self->visit_base(alloc); | ||
| }, | ||
| [&](auto *self, const Store *store) { // | ||
| result &= store->name != buf; | ||
| self->visit_base(store); | ||
| }); | ||
| return result; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| Stmt stage_strided_loads(const Stmt &s) { | ||
| Stmt stage_strided_loads(const Stmt &stmt) { | ||
| FindStridedLoads finder; | ||
| ReplaceStridedLoads replacer; | ||
|
|
||
| // Make all strided loads distinct IR nodes so that we can uniquely identify | ||
| // them by address. We may want to mutate the same load node in different | ||
| // ways depending on the surrounding context. | ||
| Stmt s = mutate_with(stmt, [&](auto *self, const Load *l) { | ||
| const Ramp *r = l->index.as<Ramp>(); | ||
| if (l->type.is_scalar() || (r && is_const_one(r->stride))) { | ||
| // Definitely not a strided load | ||
| return self->visit_base(l); | ||
| } else { | ||
| // Might be a strided load after simplification | ||
| return Load::make(l->type, l->name, self->mutate(l->index), l->image, l->param, | ||
| self->mutate(l->predicate), l->alignment); | ||
| } | ||
| }); | ||
|
|
||
| // Find related clusters of strided loads anywhere in the stmt. While this | ||
| // appears to look globally, it requires expressions to match exactly, so | ||
| // really it's only going to find things inside the same loops and let | ||
|
|
@@ -205,7 +302,6 @@ Stmt stage_strided_loads(const Stmt &s) { | |
|
|
||
| for (const auto &l : finder.found_loads) { | ||
| const FindStridedLoads::Key &k = l.first; | ||
| const Allocate *alloc = k.allocation; | ||
| const std::map<int64_t, std::vector<const Load *>> &v = l.second; | ||
|
|
||
| // Find clusters of strided loads that can share the same dense load. | ||
|
|
@@ -225,16 +321,42 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| // We have a complete cluster of loads. Make a single dense load | ||
| int lanes = k.lanes * k.stride; | ||
| int64_t first_offset = load->first; | ||
| Expr idx = Ramp::make(k.base + (int)first_offset, make_one(k.base.type()), lanes); | ||
| Expr base = common_subexpression_elimination(k.base); | ||
| Expr idx = Ramp::make(base + (int)first_offset, make_one(k.base.type()), lanes); | ||
| Type t = k.type.with_lanes(lanes); | ||
| const Load *op = load->second[0]; | ||
|
|
||
| std::set<const Load *> all_loads; | ||
| for (auto l = load; l != v.end() && l->first < first_offset + k.stride; l++) { | ||
| all_loads.insert(l->second.begin(), l->second.end()); | ||
| } | ||
|
|
||
| Expr shared_load = Load::make(t, k.buf, idx, op->image, op->param, | ||
| const_true(lanes), op->alignment); | ||
| shared_load = common_subexpression_elimination(shared_load); | ||
| for (; load != v.end() && load->first < first_offset + k.stride; load++) { | ||
| Expr shuf = Shuffle::make_slice(shared_load, load->first - first_offset, k.stride, k.lanes); | ||
| for (const Load *l : load->second) { | ||
| replacer.replacements.emplace(std::make_pair(alloc, l), shuf); | ||
|
|
||
| // We can't lift the shared load further out than the scope over | ||
| // which the loads definitely occur. If k.scope is null, the loads | ||
|
||
| // are valid everywhere (it must be an input buffer) | ||
| const IRNode *outermost = k.scope ? k.scope : s.get(); | ||
| const IRNode *let_site = innermost_containing_node(outermost, all_loads); | ||
| if (can_hoist_shared_load(let_site, k.buf, idx)) { | ||
| std::string name = unique_name('t'); | ||
| Expr var = Variable::make(shared_load.type(), name); | ||
| for (; load != v.end() && load->first < first_offset + k.stride; load++) { | ||
| int row = load->first - first_offset; | ||
| Expr shuf = Shuffle::make_slice(var, row, k.stride, k.lanes); | ||
| for (const Load *l : load->second) { | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
| } | ||
| replacer.let_injections[let_site].emplace_back(name, shared_load); | ||
| } else { | ||
| for (; load != v.end() && load->first < first_offset + k.stride; load++) { | ||
| int row = load->first - first_offset; | ||
| Expr shuf = Shuffle::make_slice(shared_load, row, k.stride, k.lanes); | ||
| for (const Load *l : load->second) { | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -243,7 +365,7 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| // picked up in a cluster, but for whom we know it's safe to do a | ||
| // dense load before their start. | ||
| for (const auto &[offset, loads] : reverse_view(v)) { | ||
| if (replacer.replacements.count({alloc, loads[0]})) { | ||
| if (replacer.replacements.count(loads[0])) { | ||
| continue; | ||
| } | ||
| int64_t delta = k.stride - 1; | ||
|
|
@@ -261,14 +383,14 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| dense_load = common_subexpression_elimination(dense_load); | ||
| Expr shuf = Shuffle::make_slice(dense_load, delta, k.stride, k.lanes); | ||
| for (const Load *l : loads) { | ||
| replacer.replacements.emplace(std::make_pair(alloc, l), shuf); | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
| } | ||
|
|
||
| // Look for any loads we can densify because an overlapping load occurs | ||
| // in any parent scope. | ||
| for (const auto &[offset, loads] : reverse_view(v)) { | ||
| if (replacer.replacements.count({alloc, loads[0]})) { | ||
| if (replacer.replacements.count(loads[0])) { | ||
| continue; | ||
| } | ||
| int64_t min_offset = offset; | ||
|
|
@@ -299,7 +421,7 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| dense_load = common_subexpression_elimination(dense_load); | ||
| Expr shuf = Shuffle::make_slice(dense_load, offset - final_offset, k.stride, k.lanes); | ||
| for (const Load *l : loads) { | ||
| replacer.replacements.emplace(std::make_pair(alloc, l), shuf); | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -308,7 +430,7 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| // external allocations by doing a dense load at a trimmed size. We rely | ||
| // on codegen to do a good job at loading vectors of a funny size. | ||
| for (const auto &[offset, loads] : v) { | ||
| if (replacer.replacements.count({alloc, loads[0]})) { | ||
| if (replacer.replacements.count(loads[0])) { | ||
| continue; | ||
| } | ||
|
|
||
|
|
@@ -332,7 +454,7 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| dense_load = common_subexpression_elimination(dense_load); | ||
| Expr shuf = Shuffle::make_slice(dense_load, offset - first_offset, k.stride, k.lanes); | ||
| for (const Load *l : loads) { | ||
| replacer.replacements.emplace(std::make_pair(alloc, l), shuf); | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
|
|
||
| } else if (k.lanes % 2 == 0) { | ||
|
|
@@ -355,7 +477,7 @@ Stmt stage_strided_loads(const Stmt &s) { | |
| Expr shuf2 = Shuffle::make_slice(dense_load2, delta, k.stride, k.lanes / 2); | ||
| Expr shuf = Shuffle::make_concat({shuf1, shuf2}); | ||
| for (const Load *l : loads) { | ||
| replacer.replacements.emplace(std::make_pair(alloc, l), shuf); | ||
| replacer.replacements.emplace(l, shuf); | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any risk of hoisting past an extern stage with a side-effect?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there might be. I'll try to construct a failure.