-
Notifications
You must be signed in to change notification settings - Fork 128
[AIMIGRAPHX-835] integrate symbolic expression in dynamic_dimension #4702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 77 commits
d2b684e
aa55785
314f7cf
dcfe825
2ec0969
18caf6b
483932b
200f3c4
7ff2045
6af3621
b7d7c23
3486135
edbce87
964f934
2719ae6
364bd23
33614e0
2ba3b74
830594f
bd70d84
def3038
359070d
9ad996f
b61b6d8
bda9f91
5027376
1209717
003c9d3
3759299
50944fb
fe649ff
7be6e7f
1274d3a
5b28774
83da044
d680d55
8d5629b
e7ca1d6
54debb5
c7f698c
149b661
293b5d8
cc521e4
3b7077a
47ec30a
59a7ef4
95894db
be87f71
78f06c7
71d7ae7
1a68619
2044073
d6c8d49
4556366
378e3d7
5412e60
0bf5680
4fa771a
2545a8e
e904332
c8b8df4
4e0da9c
ca454b2
29ca4d5
a0026ba
1db31d6
c128adc
3b2c259
49d7aa6
cc74c4a
44cd175
ae322fc
7b5484e
39e0442
5be3d69
093964c
72f00ea
812a1b1
8573cfb
718e599
d94f367
1747639
3e0e2c2
4bbb2e6
94c7941
a64a0d1
c4a33e1
3887e3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| #include <cstdint> | ||
| #include <memory> | ||
| #include <ostream> | ||
| #include <set> | ||
| #include <string> | ||
| #include <unordered_map> | ||
|
|
||
|
|
@@ -40,8 +41,21 @@ struct value; | |
|
|
||
| namespace sym { | ||
|
|
||
| struct interval | ||
| { | ||
| int64_t min = 0; | ||
| int64_t max = 0; | ||
|
shivadbhavsar marked this conversation as resolved.
Outdated
|
||
| friend bool operator==(const interval& a, const interval& b) | ||
| { | ||
| return a.min == b.min and a.max == b.max; | ||
| } | ||
| friend bool operator!=(const interval& a, const interval& b) { return not(a == b); } | ||
| }; | ||
|
|
||
| struct expr; | ||
| MIGRAPHX_EXPORT expr var(const std::string& name); | ||
| MIGRAPHX_EXPORT expr var(const std::string& name, | ||
| interval bounds = {1, 1}, | ||
| std::set<int64_t> optimals = {}); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having the bounds and optimals in the symbolic variable feels wrong to me. The bounds should be supplied from the interval?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm seeing the bigger picture now. This is expecting that an expression can have variables that each have different intervals. Making it different from the top-level interval in the dynamic_dimension. I'm thinking there should be a map between variables and their intervals stored in the program or module itself.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought I had put a comment somewhere regarding this (maybe in the shape implementation). The big picture is that symbolic shapes will not even have a top level interval in dynamic_dimension. Its implemented as optional because it should only exist for range-based shapes for backward compatibility (and any other reason we might want a pure range-based dynamic shape). For symbolic shapes, when the interval is needed, it will be computed on the fly based on the symbols in the symbolic expression for that shape |
||
| MIGRAPHX_EXPORT expr lit(int64_t n); | ||
| MIGRAPHX_EXPORT expr parse(const std::string& s); | ||
|
|
||
|
|
@@ -50,11 +64,14 @@ struct MIGRAPHX_EXPORT expr | |
| expr(); | ||
|
|
||
| bool empty() const; | ||
| bool is_literal() const; | ||
| std::size_t hash() const; | ||
| std::string to_string() const; | ||
| value to_value() const; | ||
| void from_value(const value& v); | ||
| std::size_t eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const; | ||
| interval eval_interval() const; | ||
| std::set<std::size_t> eval_optimals() const; | ||
| expr subs(const std::unordered_map<expr, expr>& symbol_map) const; | ||
|
|
||
| MIGRAPHX_EXPORT friend expr operator+(const expr& a, const expr& b); | ||
|
|
@@ -63,6 +80,10 @@ struct MIGRAPHX_EXPORT expr | |
| MIGRAPHX_EXPORT friend expr operator/(const expr& a, const expr& b); | ||
| MIGRAPHX_EXPORT friend bool operator==(const expr& a, const expr& b); | ||
| MIGRAPHX_EXPORT friend bool operator!=(const expr& a, const expr& b); | ||
| MIGRAPHX_EXPORT friend bool operator<(const expr& a, const expr& b); | ||
| friend bool operator>(const expr& a, const expr& b) { return b < a; } | ||
| friend bool operator<=(const expr& a, const expr& b) { return not(b < a); } | ||
| friend bool operator>=(const expr& a, const expr& b) { return not(a < b); } | ||
| MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const expr& e); | ||
|
|
||
| friend expr operator+(const expr& a, int64_t b) { return a + lit(b); } | ||
|
|
@@ -76,7 +97,8 @@ struct MIGRAPHX_EXPORT expr | |
|
|
||
| struct impl; | ||
|
|
||
| MIGRAPHX_EXPORT friend expr var(const std::string& name); | ||
| MIGRAPHX_EXPORT friend expr | ||
| var(const std::string& name, interval bounds, std::set<int64_t> optimals); | ||
| MIGRAPHX_EXPORT friend expr lit(int64_t n); | ||
| MIGRAPHX_EXPORT friend expr parse(const std::string& s); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| #include <migraphx/permutation.hpp> | ||
| #include <migraphx/functional.hpp> | ||
| #include <migraphx/algorithm.hpp> | ||
| #include <migraphx/sym.hpp> | ||
| #include <map> | ||
| #include <functional> | ||
|
|
||
|
|
@@ -33,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS { | |
|
|
||
| shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation) | ||
| { | ||
| if(s.symbolic()) | ||
| return {s.type(), | ||
| reorder_dims(s.dyn_dims(), permutation), | ||
| reorder_dims(s.dyn_strides(), permutation)}; | ||
| return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; | ||
| } | ||
|
|
||
|
|
@@ -43,11 +48,25 @@ std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation) | |
|
|
||
| std::vector<int64_t> find_permutation(const shape& s) | ||
| { | ||
| std::vector<std::int64_t> result(s.lens().size()); | ||
| if(s.dynamic() and not s.symbolic()) | ||
| MIGRAPHX_THROW("FIND_PERMUTATION: non-symbolic dynamic shapes not supported"); | ||
| std::vector<std::int64_t> result(s.ndim()); | ||
| std::iota(result.begin(), result.end(), 0); | ||
| std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { | ||
| return std::make_tuple(s.strides()[x], s.lens()[x]); | ||
| })); | ||
| if(s.symbolic()) | ||
| { | ||
| const auto& strides = s.dyn_strides(); | ||
| const auto& dds = s.dyn_dims(); | ||
| std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { | ||
| return std::make_tuple(strides[x].eval_interval().max, | ||
| dds[x].sym_expr.eval_interval().max); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this is finding the permutation by evaluating out the intervals and then ordering the symbolic expressions using those. Is this why the symbolic variable needs the interval? This feels incorrect to me and instead if we have the same assumptions of monotonic increasing functions, evaluating any random interval should work?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i had a large comment block on is_sorted_strides to address this very thing but somehow it got removed when i was consolidating changes after the interval refactor. It should be there now.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure the idea of permutations makes sense for symbolic shapes. Take this contrived example: Maybe instead the order for the static shape must be determined at runtime unless something like this holds :
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why would a stride expression have a division by a non constant symbol? Is there any op that would actually create such a stride? Regardless, I got claude to try and create breaking examples for this using stride manipulating shape ops (like transpose, slice, step, broadcasts, reshape), and the only way it was able to come up with an ambiguous case like this is if the input bounds are just really badly defined. Eg.: Here the input bounds dont really make sense considering the step value, and it would even give an incorrect result for the static shape case if compiling the model for something like [2,2] input shape.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay, updated to be more clear about assumptions being made, and explicitly throw when intervals cause contradictory results at boundaries
shivadbhavsar marked this conversation as resolved.
Outdated
|
||
| })); | ||
| } | ||
| else | ||
| { | ||
| std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { | ||
| return std::make_tuple(s.strides()[x], s.lens()[x]); | ||
| })); | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
|
|
@@ -64,7 +83,7 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes) | |
| } | ||
| if(count.empty()) | ||
| { | ||
| std::vector<int64_t> r(shapes.front().lens().size()); | ||
| std::vector<int64_t> r(shapes.front().ndim()); | ||
| std::iota(r.begin(), r.end(), 0); | ||
| return r; | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.