shivadbhavsar left a comment
I did a quick initial scan focusing mostly on the interface, and it should be relatively straightforward to integrate (I think a few explicit conversions from scalar to size_t need to be added in some places). I'd prefer to get #4702 in first and make sure that all the shape integration tests also pass here.
I'll do another round soon to actually go through the implementation details.
| : pimpl(make_impl(std::move(node), std::move(children))) | ||
| { | ||
| } | ||
|
|
Can we add an is_literal somewhere in here? It's needed for is_fixed in the shape class.
You can check the name of the node: expr.name() == "literal".
| scalar_max(scalar_max(p1, p2), scalar_max(p3, p4))}; | ||
| } | ||
|
|
||
| interval operator/(interval a, interval b) |
I think this should probably throw if interval b crosses 0 (e.g. [-2, 4]). It won't really affect our current use case, but it would be good for correctness.
So [-2, 4] doesn't mean we will divide by zero. These are ranges of the max and min, and that doesn't mean every value in between is always used.
Well, for floats especially, the true resulting interval would be unbounded. And for integers I'm not sure this gives the correct result either; the bounds would have to be computed at -1 and 1.
I see, the formula is derived by taking the reciprocal. The reciprocal formula is [1/max, 1/min], so for [-2, 4] it produces [1/4, -1/2], which is not a valid interval. The correct answer is the union of two intervals, (-inf, -1/2] U [1/4, inf), which as a single interval has to be (-inf, inf).
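To make the point above concrete, here is a minimal sketch of a division that widens to (-inf, inf) whenever the divisor contains 0. It is not the PR's implementation: `interval_d` and `divide` are hypothetical names, and the sketch assumes double endpoints only rather than the PR's scalar type.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical stand-in for the PR's interval type, double endpoints only.
struct interval_d
{
    double min;
    double max;
};

// Divide a by b. If b contains 0, the quotient is unbounded, so the only
// single interval covering every possible result is (-inf, inf).
inline interval_d divide(interval_d a, interval_d b)
{
    constexpr double inf = std::numeric_limits<double>::infinity();
    if(b.min <= 0.0 and b.max >= 0.0)
        return {-inf, inf};
    // b has a single sign here, so the extremes occur at endpoint quotients.
    double q1 = a.min / b.min;
    double q2 = a.min / b.max;
    double q3 = a.max / b.min;
    double q4 = a.max / b.max;
    return {std::min({q1, q2, q3, q4}), std::max({q1, q2, q3, q4})};
}
```

Returning the unbounded interval is the conservative choice; throwing instead, as first suggested, would be the stricter alternative.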
|
|
||
| if(var_optimals.empty()) | ||
| return {eval({})}; |
Why do we want to add optimals at all in this case?
| expr operator+(const expr& a, const expr& b) | ||
| bool expr::is_raw() const { return pimpl and pimpl->raw_flag; } | ||
|
|
||
| const std::vector<expr>& expr::children() const { return pimpl->children; } |
| friend bool operator==(const literal_node& a, const literal_node& b) | ||
| { | ||
| return scalar_invoke_common<bool>( | ||
| [](auto a, auto b) { return float_equal(a, b); }, a.val, b.val); |
This makes lit(1) == lit(1.0), but they will hash differently when using expr in maps and sets.
OK, let me fix the hashing.
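One way to keep hashing consistent with a type-agnostic equality is to hash the same normalized value that equality compares. The sketch below is an assumption, not the fix that landed: `scalar_v`, `scalar_equal`, and `scalar_hash` are hypothetical names, and it assumes exact comparison after conversion to double (a tolerance-based float_equal would need a different approach).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <variant>

// Hypothetical stand-in for the PR's scalar type.
using scalar_v = std::variant<std::int64_t, double>;

inline double as_double(const scalar_v& s)
{
    return std::visit([](auto x) { return static_cast<double>(x); }, s);
}

// Equality that ignores the stored type, so 1 == 1.0.
inline bool scalar_equal(const scalar_v& a, const scalar_v& b)
{
    return as_double(a) == as_double(b);
}

// Hash the same normalized value that equality compares.
inline std::size_t scalar_hash(const scalar_v& s)
{
    double d = as_double(s);
    if(d == 0.0)
        d = 0.0; // fold -0.0 into +0.0, since the two compare equal
    return std::hash<double>{}(d);
}
```

Large int64_t values that round to the same double will collide, which is allowed for a hash; the requirement is only that equal values hash equally.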
| } | ||
|
|
||
| static expr_ptr parse_unary(const char*& p) | ||
| std::size_t expr::eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const |
It might be good to explicitly throw if the eval result is negative.
| bool empty() const { return start == last; } | ||
|
|
||
| std::size_t size() const { return std::distance(start, last); } |
iterator_range::size() uses std::distance, but ranges.hpp doesn’t include <iterator>. Please include it explicitly rather than relying on transitive includes.
| auto compute = [&](const scalar& x, const scalar& y) -> scalar { | ||
| if(std::holds_alternative<int64_t>(x) and std::holds_alternative<int64_t>(y)) | ||
| return f(std::get<int64_t>(x), std::get<int64_t>(y)); | ||
| return fd(to<double>(x), to<double>(y)); | ||
| }; |
interval operator%: the integer branch uses % directly and will hit UB if any endpoint of b is 0 (or if b spans 0). This should be detected and handled explicitly.
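A minimal sketch of the explicit handling asked for here, under the assumption that endpoints are plain int64_t values. `checked_mod` and `spans_zero` are hypothetical helpers, and std::domain_error stands in for MIGRAPHX_THROW.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: integer modulo with the undefined cases made explicit.
inline std::int64_t checked_mod(std::int64_t x, std::int64_t y)
{
    if(y == 0)
        throw std::domain_error("modulo by zero");
    if(y == -1)
        return 0; // x % -1 is always 0; this also avoids INT64_MIN % -1
    return x % y;
}

// True when the closed range [lo, hi] contains 0, i.e. some divisor may be 0.
inline bool spans_zero(std::int64_t lo, std::int64_t hi) { return lo <= 0 and hi >= 0; }
```

An interval operator% could call spans_zero on b first and then either throw or return a conservative bound, before applying checked_mod to the endpoints.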
| if((std::isdigit(p.peek_char()) == 0) and p.peek_char() != '.') | ||
| return {}; |
std::isdigit(p.peek_char()) passes a char directly into ctype; if char is signed this is UB for negative values. Cast to unsigned char before calling ctype functions (consistent with src/driver/main.cpp:357-359).
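The cast can be centralized in small wrappers so that call sites cannot forget it. This is only a sketch: the function names below are hypothetical and do not exist in the repository.

```cpp
#include <cassert>
#include <cctype>

// Hypothetical wrappers: convert to unsigned char before calling <cctype>,
// so bytes >= 0x80 are never passed as negative ints.
inline bool is_digit_char(char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; }

inline bool is_alpha_char(char c) { return std::isalpha(static_cast<unsigned char>(c)) != 0; }

inline bool is_ident_char(char c)
{
    return std::isalnum(static_cast<unsigned char>(c)) != 0 or c == '_';
}
```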
| overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; }, | ||
| [](const auto&) -> std::optional<scalar> { return std::nullopt; }}, |
In eval_uint, the replacer returns nullopt for variable_node when the variable isn’t in the map. For a leaf variable this means generic_eval will eventually try std::get<op_node>(...) on a variable_node, throwing std::bad_variant_access instead of a controlled MIGRAPHX_THROW (e.g. “unbound symbol”). Handle unbound variables explicitly in the replacer (and consider validating that map keys are variables).
Suggested change:
-     overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; },
-                [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
+     overloaded{
+         [](const literal_node& n) -> std::optional<scalar> { return n.val; },
+         [](const variable_node&) -> std::optional<scalar> {
+             MIGRAPHX_THROW("Unbound symbol in eval_uint");
+         },
+         [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
| if constexpr(std::is_integral<T>{}) | ||
| return lit(static_cast<int64_t>(x)); | ||
| else | ||
| return lit(static_cast<double>(x)); |
arg(T) casts integral values to int64_t directly. For large unsigned values this is implementation-defined / lossy and can wrap. Prefer lit(make_scalar(x)) here to match the clipping/typing rules used elsewhere.
Suggested change:
-     if constexpr(std::is_integral<T>{})
-         return lit(static_cast<int64_t>(x));
-     else
-         return lit(static_cast<double>(x));
+     return lit(make_scalar(x));
| MIGRAPHX_CALL_FUNC(ceil), | ||
| }; | ||
| #undef MIGRAPHX_CALL_FUNC | ||
| return functions.at(name)(args); |
call_function uses functions.at(name), so unknown names throw std::out_of_range and lose useful context. Prefer a find() + MIGRAPHX_THROW that includes the unknown function/operator name (and parse position when called from parsing).
Suggested change:
-     return functions.at(name)(args);
+     auto it = functions.find(name);
+     if(it == functions.end())
+         MIGRAPHX_THROW("Unknown function/operator: " + name);
+     return it->second(args);
| char c = p.peek_char(); | ||
| if((std::isalpha(c) == 0) and c != '_') | ||
| return {}; | ||
| auto name = p.parse_while([](unsigned char ch) { return std::isalnum(ch) or ch == '_'; }); |
Same issue here: std::isalpha(c) is called with a char. Cast to unsigned char (and similarly for std::isalnum) to avoid UB on non-ASCII bytes.
| return to<std::size_t>(generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> { | ||
| auto it = symbol_map.find(e); |
Also, eval_uint converts the final scalar to std::size_t without checking for negative results. A negative expression value will silently wrap to a huge size_t. Consider preserving the previous behavior of throwing if the evaluated value is negative.
| auto letters = p.parse_while([](char c) { return std::isalpha(c); }); | ||
| EXPECT(letters == "abc"); | ||
| auto digits = p.parse_while([](char c) { return std::isdigit(c); }); |
std::isalpha / std::isdigit are called with char values here. If char is signed, passing negative values to ctype is UB. Please cast to unsigned char in these predicates (consistent with src/driver/main.cpp:357-359).
Suggested change:
-     auto letters = p.parse_while([](char c) { return std::isalpha(c); });
-     EXPECT(letters == "abc");
-     auto digits = p.parse_while([](char c) { return std::isdigit(c); });
+     auto letters =
+         p.parse_while([](char c) { return std::isalpha(static_cast<unsigned char>(c)); });
+     EXPECT(letters == "abc");
+     auto digits =
+         p.parse_while([](char c) { return std::isdigit(static_cast<unsigned char>(c)); });
| template <class Iterator> | ||
| void hash_range(std::size_t& seed, Iterator first, Iterator last) | ||
| { | ||
| std::for_each(first, last, [&](const auto& x) { hash_combine(seed, x); }); | ||
| } |
hash_range uses std::for_each, but this header doesn’t include <algorithm>, which can break compilation depending on include order. Please add <algorithm> (or avoid std::for_each).
| char c = p.peek_char(); | ||
| if((std::isalpha(c) == 0) and c != '_') | ||
| return {}; | ||
| auto name = p.parse_while([](unsigned char ch) { return std::isalnum(ch) or ch == '_'; }); |
parse_func_or_var uses std::isalpha(c) on a char from peek_char(). In this repo we typically cast to unsigned char before ctype functions to avoid UB when char is signed (e.g. src/driver/main.cpp:357-359). Please use std::isalpha(static_cast<unsigned char>(c)) here (and similarly for any other new ctype calls on char).
| interval log(interval x) { return {std::log(to<double>(x.min)), std::log(to<double>(x.max))}; } | ||
|
|
||
| static add_parts extract_add(const expr_ptr& e) | ||
| interval sqrt(interval x) | ||
| { | ||
| return std::visit( | ||
| overloaded{[](const integer_data& d) -> add_parts { return {d.value, {}}; }, | ||
| [](const add_data& d) -> add_parts { return {d.constant, d.terms}; }, | ||
| [&](const mul_data& d) -> add_parts { | ||
| auto base = build_mul(1, d.factors); | ||
| return {0, {{base, d.coefficient}}}; | ||
| }, | ||
| [&](const auto&) -> add_parts { return {0, {{e, 1}}}; }}, | ||
| e->data); | ||
| auto lo = std::sqrt(std::max(0.0, to<double>(x.min))); | ||
| auto hi = std::sqrt(std::max(0.0, to<double>(x.max))); | ||
| return {lo, hi}; |
interval log(interval) and interval sqrt(interval) don't handle out-of-domain inputs robustly. log on an interval with min <= 0 will produce NaNs via std::log, and sqrt currently clamps negative endpoints to 0 which can yield a seemingly-valid interval for an invalid domain (e.g. sqrt([-4,-1]) becomes [0,0]). Consider returning a conservative extended-real bound (e.g. log([-1, 5]) -> [-inf, log(5)]) or throwing when the interval is entirely out of domain (max <= 0 for log, max < 0 for sqrt).
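The policy described above (extended-real bound when the interval partly overlaps the domain, error when it lies entirely outside) could look roughly like this. It is a sketch under assumptions: `real_interval`, `log_interval`, and `sqrt_interval` are hypothetical names, endpoints are plain doubles, and std::domain_error stands in for MIGRAPHX_THROW.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <stdexcept>

// Hypothetical stand-in for the PR's interval type, double endpoints only.
struct real_interval
{
    double min;
    double max;
};

// log: throw when no point of x is in the domain, otherwise widen the
// lower bound to -inf if x reaches 0 or below.
inline real_interval log_interval(real_interval x)
{
    if(x.max <= 0.0)
        throw std::domain_error("log: interval entirely out of domain");
    constexpr double inf = std::numeric_limits<double>::infinity();
    double lo            = x.min > 0.0 ? std::log(x.min) : -inf;
    return {lo, std::log(x.max)};
}

// sqrt: throw when x is entirely negative, otherwise clamp only the
// lower endpoint to 0.
inline real_interval sqrt_interval(real_interval x)
{
    if(x.max < 0.0)
        throw std::domain_error("sqrt: interval entirely out of domain");
    return {std::sqrt(std::max(0.0, x.min)), std::sqrt(x.max)};
}
```

With this policy sqrt([-4, -1]) raises an error instead of silently becoming [0, 0].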
| std::size_t expr::eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const | ||
| { | ||
| if(empty()) | ||
| return 0; | ||
| binding_map bindings; | ||
| for(const auto& [k, v] : symbol_map) | ||
| { | ||
| if(k.empty() or not holds<symbol_data>(k.p->node)) | ||
| MIGRAPHX_THROW("sym::expr::eval_uint: map key '" + k.to_string() + "' is not a symbol"); | ||
| bindings[k.p->node] = v; | ||
| } | ||
| auto v = eval_direct(p->node, bindings); | ||
| if(v < 0) | ||
| MIGRAPHX_THROW("sym::expr::eval_uint: expression evaluated to negative value"); | ||
| return v; | ||
| return to<std::size_t>(generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> { | ||
| auto it = symbol_map.find(e); | ||
| if(it != symbol_map.end()) | ||
| return make_scalar(it->second); | ||
| return std::visit( | ||
| overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; }, | ||
| [](const auto&) -> std::optional<scalar> { return std::nullopt; }}, | ||
| get_node(e)); | ||
| })); | ||
| } |
expr::eval_uint no longer handles empty expressions and can hit get_node(e) on an empty expr (assert/deref). It also converts the evaluated scalar to std::size_t without checking for negative or non-integer results, so lit(-1).eval_uint({}) would silently wrap to a huge value. Please add an explicit empty guard and validate the evaluated scalar is a non-negative integer before converting/returning.
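A sketch of the validation step requested here, separated from the evaluator itself. `scalar_value` and `to_size_checked` are hypothetical names, and std::runtime_error stands in for MIGRAPHX_THROW; the empty-expression guard would sit in eval_uint before this conversion is reached.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <variant>

// Hypothetical stand-in for the PR's scalar type.
using scalar_value = std::variant<std::int64_t, double>;

// Convert to std::size_t only if the value is a non-negative integer that fits.
inline std::size_t to_size_checked(const scalar_value& v)
{
    if(const auto* i = std::get_if<std::int64_t>(&v))
    {
        if(*i < 0)
            throw std::runtime_error("eval_uint: negative result");
        return static_cast<std::size_t>(*i);
    }
    double d = std::get<double>(v);
    // 2^digits is the first double that no longer fits in std::size_t.
    const double limit = std::ldexp(1.0, std::numeric_limits<std::size_t>::digits);
    // not(d >= 0.0) also rejects NaN; the limit check also rejects +inf.
    if(not(d >= 0.0) or d != std::floor(d) or d >= limit)
        throw std::runtime_error("eval_uint: result is not a non-negative integer");
    return static_cast<std::size_t>(d);
}
```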
| scalar expr::eval(const std::unordered_map<expr, scalar>& vars) const | ||
| { | ||
| if(a.empty() or b.empty()) | ||
| return {}; | ||
| return {std::make_shared<expr::impl>(make_add(a.p->node, b.p->node))}; | ||
| return generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> { | ||
| auto it = vars.find(e); | ||
| if(it != vars.end()) | ||
| return it->second; | ||
| return std::visit( | ||
| overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; }, | ||
| [](const auto&) -> std::optional<scalar> { return std::nullopt; }}, | ||
| get_node(e)); | ||
| }); |
expr::eval calls get_node(e) via the replace lambda even when *this is empty; get_node asserts pimpl != nullptr, so evaluating an empty expression will crash. Please add an explicit if(empty()) behavior (e.g., throw MIGRAPHX_THROW with a clear message) or handle empty in the replace lambda before touching get_node.
| if((std::isdigit(p.peek_char()) == 0) and p.peek_char() != '.') | ||
| return {}; | ||
| auto token = p.parse_while([](unsigned char c) { return std::isdigit(c) or c == '.'; }); | ||
| bool is_float = token.find('.') != std::string_view::npos; | ||
| if(is_float) | ||
| return lit(std::stod(std::string(token))); | ||
| return lit(std::stoll(std::string(token))); |
parse_number calls std::isdigit(p.peek_char()) where peek_char() is a char; if char is signed and the input contains non-ASCII bytes this is undefined behavior. Also, tokenization accepts any mix of digits and '.' (e.g. "." or "1..2"), which can make std::stod/std::stoll throw a standard exception rather than a MIGRAPHX_THROW with location context. Consider casting to unsigned char for ctype calls, validating the numeric token (single '.' / at least one digit), and converting parse failures into MIGRAPHX_THROW(p.error_message(...)).
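The token validation mentioned above (at most one '.', at least one digit) can be sketched as a standalone check that runs before std::stod/std::stoll. `is_valid_number_token` is a hypothetical helper, not part of the PR; the caller would turn a false result into a MIGRAPHX_THROW with position context.

```cpp
#include <cassert>
#include <cctype>
#include <string_view>

// Hypothetical helper: accept only digits with at most one '.', and
// require at least one digit overall.
inline bool is_valid_number_token(std::string_view token)
{
    bool seen_digit = false;
    bool seen_dot   = false;
    for(char c : token)
    {
        if(c == '.')
        {
            if(seen_dot)
                return false; // a second '.' such as in "1..2"
            seen_dot = true;
        }
        else if(std::isdigit(static_cast<unsigned char>(c)) != 0)
        {
            seen_digit = true;
        }
        else
        {
            return false;
        }
    }
    return seen_digit; // rejects "" and "."
}
```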
| interval expr::eval_interval(const std::unordered_map<expr, interval>& vars) const | ||
| { | ||
| if(a.empty() or b.empty()) | ||
| return {}; | ||
| return {std::make_shared<expr::impl>(make_sub(a.p->node, b.p->node))}; | ||
| return generic_eval<interval>(*this, [&](const expr& e) -> std::optional<interval> { | ||
| auto it = vars.find(e); | ||
| if(it != vars.end()) | ||
| return it->second; | ||
| return std::visit( | ||
| overloaded{[](const literal_node& n) -> std::optional<interval> { | ||
| return interval{n.val, n.val}; | ||
| }, | ||
| [](const variable_node& n) -> std::optional<interval> { | ||
| if(not n.constraints.empty()) | ||
| return n.constraints.front(); | ||
| MIGRAPHX_THROW("Variable '" + n.name + "' not found in interval map"); | ||
| }, | ||
| [](const op_node&) -> std::optional<interval> { return std::nullopt; }}, | ||
| get_node(e)); | ||
| }); |
expr::eval_interval has the same empty-expression issue as eval: if *this is empty, the replace lambda falls through to get_node(e) and will assert/crash. Consider handling empty() up front (or in the replace lambda) with a consistent policy (throw or return a sentinel interval).
| static std::string scalar_to_string(const scalar& v) | ||
| { | ||
| if(a.empty() and b.empty()) | ||
| return true; | ||
| if(a.empty() != b.empty()) | ||
| return false; | ||
| return expr_equal(a.p->node, b.p->node); | ||
| return std::visit( | ||
| [](auto x) -> std::string { | ||
| std::ostringstream ss; | ||
| ss << x; | ||
| return ss.str(); | ||
| }, | ||
| v); |
scalar_to_string uses default std::ostringstream formatting for doubles, which may emit scientific notation (e.g. 1e-06). However parse_number only accepts digits and '.' and will reject/throw on that output, breaking parse(to_string(expr)) for many valid doubles. Please either (a) extend parse_number to accept exponent forms, or (b) force a non-scientific, round-trippable formatting in scalar_to_string (and add coverage for such cases).
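For reference, a small sketch of the formatting side, assuming option (a) is taken and exponent forms are accepted when parsing. Default stream precision is 6 significant digits, which also loses information for values such as 1.0/3.0; printing with max_digits10 makes the text round-trip through std::stod, which already accepts exponents. `format_double` is a hypothetical name.

```cpp
#include <cassert>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>

// Hypothetical formatter: enough significant digits for the text to parse
// back to the same double. The output may still use scientific notation.
inline std::string format_double(double x)
{
    std::ostringstream ss;
    ss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
    return ss.str();
}
```

The tokenizer in parse_number would still need to scan an optional exponent suffix for such strings to be accepted; this sketch only covers the string produced by the printer.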
| auto inner = lit(7) * lit(2); | ||
| EXPECT(inner.eval_uint({}) == 14); |
eval_uint_symbol_map_partial currently doesn't use e or the symbol map at all; it just evaluates lit(7) * lit(2). This doesn't test partial evaluation/mapping behavior as the name/comment suggest. Consider changing it to actually call e.eval_uint({{x, 7}}) (or removing/renaming the test if partial mapping isn't supported).
Suggested change:
-     auto inner = lit(7) * lit(2);
-     EXPECT(inner.eval_uint({}) == 14);
+     EXPECT(e.eval_uint({{x, 7}}) == 14);
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Motivation
This PR introduces a revamped symbolic expression subsystem (migraphx::sym) with richer scalar/interval semantics, parsing, rewriting/simplification, and broad new test coverage. It also adds a reusable lightweight parser utility used by the symbolic expression parser.
Technical Details
Changelog Category
Add a CHANGELOG.md entry for any option other than Not Applicable