Add symbolic expression #4782

Open
pfultz2 wants to merge 159 commits into develop from sym-expr2

Conversation

pfultz2 (Collaborator) commented Apr 13, 2026

Motivation

This PR introduces a revamped symbolic expression subsystem (migraphx::sym) with richer scalar/interval semantics, parsing, rewriting/simplification, and broad new test coverage. It also adds a reusable lightweight parser utility used by the symbolic expression parser.

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

shivadbhavsar (Contributor) left a comment

I did a quick initial scan focusing mostly on the interface, and it should be relatively straightforward to integrate (I think a few explicit conversions from scalar to size_t need to be added in some places). I'd prefer to get #4702 in first and make sure that all the shape integration tests also pass here.

I'll do another round soon to actually go through the implementation details.

: pimpl(make_impl(std::move(node), std::move(children)))
{
}

Contributor

Can we add an is_literal somewhere in here? It's needed for is_fixed in the shape class.

Collaborator Author

You can check the name of the node: expr.name() == "literal".

Comment thread src/sym.cpp
scalar_max(scalar_max(p1, p2), scalar_max(p3, p4))};
}

interval operator/(interval a, interval b)
Contributor

I think this should probably throw if interval b crosses 0 (e.g. [-2, 4]). It won't really affect our current use case, but it would be good for correctness.

Collaborator Author

So [-2, 4] doesn't mean we will divide by zero. These are ranges of min and max values; it doesn't mean every value in between is always used.

Contributor

Well, for floats especially, the true resulting interval would be unbounded. And for integers I'm not sure this gives the correct result either; the bounds would have to be computed at -1 and 1.

Collaborator Author

I see, the formula is derived by taking the reciprocal. The reciprocal formula is [1/max, 1/min], so for [-2, 4] it produces [1/4, -1/2], which is not a valid interval. The correct answer is the union of two intervals, (-inf, -1/2] U [1/4, inf), which as a single interval has to be (-inf, inf).
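
The reasoning in this thread can be sketched as follows. This is a minimal illustration, not the PR's code: the `interval` struct here is a hypothetical stand-in that stores plain doubles, and `contains_zero` and `divide` are names invented for the example. When the divisor contains zero the sketch returns the fully unbounded interval; otherwise it multiplies by the reciprocal and takes the extreme endpoint products.

```cpp
#include <algorithm>
#include <limits>

// Hypothetical stand-in for the PR's interval type: closed bounds held as doubles.
struct interval
{
    double min;
    double max;
};

// True when zero lies inside (or on the edge of) the interval.
inline bool contains_zero(interval x) { return x.min <= 0.0 and x.max >= 0.0; }

// a / b evaluated as a * (1 / b).
inline interval divide(interval a, interval b)
{
    constexpr double inf = std::numeric_limits<double>::infinity();
    // 1/x is unbounded near zero, so the only single interval that covers
    // every possible quotient is (-inf, inf).
    if(contains_zero(b))
        return {-inf, inf};
    // b has one sign, so 1/x is monotonically decreasing and the bounds swap.
    interval r{1.0 / b.max, 1.0 / b.min};
    double p1 = a.min * r.min;
    double p2 = a.min * r.max;
    double p3 = a.max * r.min;
    double p4 = a.max * r.max;
    return {std::min({p1, p2, p3, p4}), std::max({p1, p2, p3, p4})};
}
```

With this sketch, `divide({1.0, 2.0}, {-2.0, 4.0})` yields an unbounded result instead of an inverted interval, while same-sign divisors keep the usual four-product bounds.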

Comment thread src/sym.cpp Outdated
Comment on lines +1097 to +1099

if(var_optimals.empty())
return {eval({})};
Contributor

Why do we want to add optimals at all in this case?

Comment thread src/sym.cpp Outdated
expr operator+(const expr& a, const expr& b)
bool expr::is_raw() const { return pimpl and pimpl->raw_flag; }

const std::vector<expr>& expr::children() const { return pimpl->children; }
Contributor

Missing null check?

Comment thread src/sym.cpp
friend bool operator==(const literal_node& a, const literal_node& b)
{
return scalar_invoke_common<bool>(
[](auto a, auto b) { return float_equal(a, b); }, a.val, b.val);
Contributor

This makes lit(1) == lit(1.0), but they will hash differently when expr is used in maps and sets.

Collaborator Author

OK, let me fix the hashing.
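
One way to keep hashing consistent with a numeric equality is to normalize the value before hashing. The sketch below is an assumption about how that could look, not the fix that landed in the PR: `scalar` is modelled as a plain `std::variant`, and `scalar_equal` and `scalar_hash` are hypothetical names. Both alternatives are hashed through their `double` value, so any two scalars that compare equal also hash equal.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <variant>

// Hypothetical model of the scalar type: either an integer or a double.
using scalar = std::variant<std::int64_t, double>;

// Convert either alternative to double.
inline double as_double(const scalar& s)
{
    return std::visit([](auto x) { return static_cast<double>(x); }, s);
}

// Numeric equality: exact for two integers, by double value otherwise.
inline bool scalar_equal(const scalar& a, const scalar& b)
{
    if(std::holds_alternative<std::int64_t>(a) and std::holds_alternative<std::int64_t>(b))
        return std::get<std::int64_t>(a) == std::get<std::int64_t>(b);
    return as_double(a) == as_double(b);
}

// Hash through the double value so that equal scalars always collide.
// Distinct large integers may share a hash, which is allowed for a hash function.
inline std::size_t scalar_hash(const scalar& s) { return std::hash<double>{}(as_double(s)); }
```

The trade-off is a few extra collisions for integers above 2^53, in exchange for the invariant that unordered containers rely on: equal keys must produce equal hashes.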

Comment thread src/sym.cpp
}

static expr_ptr parse_unary(const char*& p)
std::size_t expr::eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const
Contributor

It might be good to throw explicitly if the eval result is negative.

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 10 out of 11 changed files in this pull request and generated 12 comments.

Comment on lines +272 to +274
bool empty() const { return start == last; }

std::size_t size() const { return std::distance(start, last); }
Copilot AI Apr 24, 2026

iterator_range::size() uses std::distance, but ranges.hpp doesn’t include <iterator>. Please include it explicitly rather than relying on transitive includes.

Comment thread src/sym.cpp Outdated
Comment on lines +93 to +97
auto compute = [&](const scalar& x, const scalar& y) -> scalar {
if(std::holds_alternative<int64_t>(x) and std::holds_alternative<int64_t>(y))
return f(std::get<int64_t>(x), std::get<int64_t>(y));
return fd(to<double>(x), to<double>(y));
};
Copilot AI Apr 24, 2026

interval operator%: the integer branch uses % directly and will hit UB if any endpoint of b is 0 (or if b spans 0). This should be detected and handled explicitly.

Comment thread src/sym.cpp
Comment on lines +1408 to +1409
if((std::isdigit(p.peek_char()) == 0) and p.peek_char() != '.')
return {};
Copilot AI Apr 24, 2026

std::isdigit(p.peek_char()) passes a char directly into ctype; if char is signed this is UB for negative values. Cast to unsigned char before calling ctype functions (consistent with src/driver/main.cpp:357-359).
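
One way to apply this consistently is to route every ctype call through small wrappers that do the cast once. The helper names below are hypothetical and only illustrate the pattern; they are not part of the PR.

```cpp
#include <cctype>

// The <cctype> functions require a value representable as unsigned char (or EOF),
// so a possibly negative char is cast before the call.
inline bool is_digit(char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; }

inline bool is_alpha(char c) { return std::isalpha(static_cast<unsigned char>(c)) != 0; }

// Characters allowed inside an identifier: letters, digits and underscore.
inline bool is_ident_char(char c)
{
    return std::isalnum(static_cast<unsigned char>(c)) != 0 or c == '_';
}
```

A predicate such as `[](char c) { return is_ident_char(c); }` can then be passed to a parser without repeating the cast at each call site.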

Comment thread src/sym.cpp
Comment on lines +944 to +945
overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; },
[](const auto&) -> std::optional<scalar> { return std::nullopt; }},
Copilot AI Apr 24, 2026

In eval_uint, the replacer returns nullopt for variable_node when the variable isn’t in the map. For a leaf variable this means generic_eval will eventually try std::get<op_node>(...) on a variable_node, throwing std::bad_variant_access instead of a controlled MIGRAPHX_THROW (e.g. “unbound symbol”). Handle unbound variables explicitly in the replacer (and consider validating that map keys are variables).

Suggested change
-overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; },
-           [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
+overloaded{
+    [](const literal_node& n) -> std::optional<scalar> { return n.val; },
+    [](const variable_node&) -> std::optional<scalar> {
+        MIGRAPHX_THROW("Unbound symbol in eval_uint");
+    },
+    [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
Comment on lines +233 to +236
if constexpr(std::is_integral<T>{})
return lit(static_cast<int64_t>(x));
else
return lit(static_cast<double>(x));
Copilot AI Apr 24, 2026

arg(T) casts integral values to int64_t directly. For large unsigned values this is implementation-defined / lossy and can wrap. Prefer lit(make_scalar(x)) here to match the clipping/typing rules used elsewhere.

Suggested change
-if constexpr(std::is_integral<T>{})
-    return lit(static_cast<int64_t>(x));
-else
-    return lit(static_cast<double>(x));
+return lit(make_scalar(x));
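
The thread does not show what make_scalar does internally, so the following is only an assumption about what a clipping conversion could look like. `clip_to_int64` is a hypothetical helper: unsigned values above the int64_t maximum saturate instead of wrapping.

```cpp
#include <cstdint>
#include <limits>
#include <type_traits>

// Convert any integral value to int64_t, saturating at the int64_t maximum.
template <class T>
std::int64_t clip_to_int64(T x)
{
    static_assert(std::is_integral<T>{}, "clip_to_int64 expects an integral type");
    constexpr auto max = std::numeric_limits<std::int64_t>::max();
    if constexpr(std::is_unsigned<T>{})
    {
        // Only unsigned 64-bit inputs can exceed the int64_t range.
        if(x > static_cast<std::uint64_t>(max))
            return max;
    }
    // Every remaining value is representable, so the cast is lossless.
    return static_cast<std::int64_t>(x);
}
```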

Comment thread src/sym.cpp
MIGRAPHX_CALL_FUNC(ceil),
};
#undef MIGRAPHX_CALL_FUNC
return functions.at(name)(args);
Copilot AI Apr 24, 2026

call_function uses functions.at(name), so unknown names throw std::out_of_range and lose useful context. Prefer a find() + MIGRAPHX_THROW that includes the unknown function/operator name (and parse position when called from parsing).

Suggested change
-return functions.at(name)(args);
+auto it = functions.find(name);
+if(it == functions.end())
+    MIGRAPHX_THROW("Unknown function/operator: " + name);
+return it->second(args);
Comment thread src/sym.cpp
Comment on lines +1419 to +1422
char c = p.peek_char();
if((std::isalpha(c) == 0) and c != '_')
return {};
auto name = p.parse_while([](unsigned char ch) { return std::isalnum(ch) or ch == '_'; });
Copilot AI Apr 24, 2026

Same issue here: std::isalpha(c) is called with a char. Cast to unsigned char (and similarly for std::isalnum) to avoid UB on non-ASCII bytes.

Comment thread src/sym.cpp
Comment on lines +939 to +940
return to<std::size_t>(generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> {
auto it = symbol_map.find(e);
Copilot AI Apr 24, 2026

Also, eval_uint converts the final scalar to std::size_t without checking for negative results. A negative expression value will silently wrap to a huge size_t. Consider preserving the previous behavior of throwing if the evaluated value is negative.

Comment thread test/simple_parser.cpp
Comment on lines +64 to +66
auto letters = p.parse_while([](char c) { return std::isalpha(c); });
EXPECT(letters == "abc");
auto digits = p.parse_while([](char c) { return std::isdigit(c); });
Copilot AI Apr 24, 2026

std::isalpha / std::isdigit are called with char values here. If char is signed, passing negative values to ctype is UB. Please cast to unsigned char in these predicates (consistent with src/driver/main.cpp:357-359).

Suggested change
-auto letters = p.parse_while([](char c) { return std::isalpha(c); });
+auto letters =
+    p.parse_while([](char c) { return std::isalpha(static_cast<unsigned char>(c)); });
 EXPECT(letters == "abc");
-auto digits = p.parse_while([](char c) { return std::isdigit(c); });
+auto digits =
+    p.parse_while([](char c) { return std::isdigit(static_cast<unsigned char>(c)); });
Comment on lines +58 to +62
template <class Iterator>
void hash_range(std::size_t& seed, Iterator first, Iterator last)
{
std::for_each(first, last, [&](const auto& x) { hash_combine(seed, x); });
}
Copilot AI Apr 24, 2026

hash_range uses std::for_each, but this header doesn’t include <algorithm>, which can break compilation depending on include order. Please add <algorithm> (or avoid std::for_each).

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 10 out of 11 changed files in this pull request and generated 14 comments.

Comment thread src/sym.cpp
Comment on lines +1453 to +1456
char c = p.peek_char();
if((std::isalpha(c) == 0) and c != '_')
return {};
auto name = p.parse_while([](unsigned char ch) { return std::isalnum(ch) or ch == '_'; });
Copilot AI Apr 24, 2026

parse_func_or_var uses std::isalpha(c) on a char from peek_char(). In this repo we typically cast to unsigned char before ctype functions to avoid UB when char is signed (e.g. src/driver/main.cpp:357-359). Please use std::isalpha(static_cast<unsigned char>(c)) here (and similarly for any other new ctype calls on char).

Comment thread src/sym.cpp
Comment on lines +218 to +224
+interval log(interval x) { return {std::log(to<double>(x.min)), std::log(to<double>(x.max))}; }

-static add_parts extract_add(const expr_ptr& e)
+interval sqrt(interval x)
 {
-    return std::visit(
-        overloaded{[](const integer_data& d) -> add_parts { return {d.value, {}}; },
-                   [](const add_data& d) -> add_parts { return {d.constant, d.terms}; },
-                   [&](const mul_data& d) -> add_parts {
-                       auto base = build_mul(1, d.factors);
-                       return {0, {{base, d.coefficient}}};
-                   },
-                   [&](const auto&) -> add_parts { return {0, {{e, 1}}}; }},
-        e->data);
+    auto lo = std::sqrt(std::max(0.0, to<double>(x.min)));
+    auto hi = std::sqrt(std::max(0.0, to<double>(x.max)));
+    return {lo, hi};

Copilot AI Apr 24, 2026

interval log(interval) and interval sqrt(interval) don't handle out-of-domain inputs robustly. log on an interval with min <= 0 will produce NaNs via std::log, and sqrt currently clamps negative endpoints to 0 which can yield a seemingly-valid interval for an invalid domain (e.g. sqrt([-4,-1]) becomes [0,0]). Consider returning a conservative extended-real bound (e.g. log([-1, 5]) -> [-inf, log(5)]) or throwing when the interval is entirely out of domain (max <= 0 for log, max < 0 for sqrt).
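
The two policies mentioned here (a conservative bound when the interval only partly leaves the domain, an error when it lies entirely outside) could be combined roughly as below. This is a sketch under assumptions: `interval` is a hypothetical double-only stand-in, and `safe_log` and `safe_sqrt` are names chosen for the example.

```cpp
#include <cmath>
#include <limits>
#include <stdexcept>

// Hypothetical stand-in for the PR's interval type.
struct interval
{
    double min;
    double max;
};

// log over an interval: throw if no point is in the domain, otherwise
// widen the lower bound to -inf when the interval reaches zero or below.
inline interval safe_log(interval x)
{
    if(x.max <= 0.0)
        throw std::domain_error("log: interval is entirely out of domain");
    constexpr double inf = std::numeric_limits<double>::infinity();
    double lo = x.min > 0.0 ? std::log(x.min) : -inf;
    return {lo, std::log(x.max)};
}

// sqrt over an interval: throw if every point is negative, otherwise
// clamp only the lower endpoint to the edge of the domain.
inline interval safe_sqrt(interval x)
{
    if(x.max < 0.0)
        throw std::domain_error("sqrt: interval is entirely out of domain");
    double lo = x.min > 0.0 ? std::sqrt(x.min) : 0.0;
    return {lo, std::sqrt(x.max)};
}
```

With these definitions `safe_log({-1.0, 5.0})` gives a lower bound of -inf and an upper bound of `std::log(5.0)`, and `safe_sqrt({-4.0, -1.0})` throws rather than returning [0, 0].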

Comment thread src/sym.cpp
Comment on lines 971 to 982
 std::size_t expr::eval_uint(const std::unordered_map<expr, std::size_t>& symbol_map) const
 {
-    if(empty())
-        return 0;
-    binding_map bindings;
-    for(const auto& [k, v] : symbol_map)
-    {
-        if(k.empty() or not holds<symbol_data>(k.p->node))
-            MIGRAPHX_THROW("sym::expr::eval_uint: map key '" + k.to_string() + "' is not a symbol");
-        bindings[k.p->node] = v;
-    }
-    auto v = eval_direct(p->node, bindings);
-    if(v < 0)
-        MIGRAPHX_THROW("sym::expr::eval_uint: expression evaluated to negative value");
-    return v;
+    return to<std::size_t>(generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> {
+        auto it = symbol_map.find(e);
+        if(it != symbol_map.end())
+            return make_scalar(it->second);
+        return std::visit(
+            overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; },
+                       [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
+            get_node(e));
+    }));
 }

Copilot AI Apr 24, 2026

expr::eval_uint no longer handles empty expressions and can hit get_node(e) on an empty expr (assert/deref). It also converts the evaluated scalar to std::size_t without checking for negative or non-integer results, so lit(-1).eval_uint({}) would silently wrap to a huge value. Please add an explicit empty guard and validate the evaluated scalar is a non-negative integer before converting/returning.
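
The validation requested here can be sketched as a checked conversion applied to the evaluated value just before returning. The code below is illustrative only: `scalar` is modelled as a plain `std::variant`, `to_size_checked` is a hypothetical name, and standard exceptions stand in for MIGRAPHX_THROW.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <variant>

// Hypothetical model of the scalar type: either an integer or a double.
using scalar = std::variant<std::int64_t, double>;

// Convert to std::size_t only if the value is a non-negative integer that fits.
inline std::size_t to_size_checked(const scalar& v)
{
    if(const auto* i = std::get_if<std::int64_t>(&v))
    {
        if(*i < 0)
            throw std::domain_error("eval_uint: negative result");
        if(static_cast<std::uint64_t>(*i) > std::numeric_limits<std::size_t>::max())
            throw std::range_error("eval_uint: result does not fit in size_t");
        return static_cast<std::size_t>(*i);
    }
    double d = std::get<double>(v);
    // not(d >= 0.0) also rejects NaN; the floor test rejects fractional values.
    if(not(d >= 0.0) or d != std::floor(d))
        throw std::domain_error("eval_uint: result is not a non-negative integer");
    // 2^digits is the first value that no longer fits in size_t.
    if(d >= std::ldexp(1.0, std::numeric_limits<std::size_t>::digits))
        throw std::range_error("eval_uint: result does not fit in size_t");
    return static_cast<std::size_t>(d);
}
```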

Comment thread src/sym.cpp
Comment on lines +1081 to +1091
+scalar expr::eval(const std::unordered_map<expr, scalar>& vars) const
 {
-    if(a.empty() or b.empty())
-        return {};
-    return {std::make_shared<expr::impl>(make_add(a.p->node, b.p->node))};
+    return generic_eval<scalar>(*this, [&](const expr& e) -> std::optional<scalar> {
+        auto it = vars.find(e);
+        if(it != vars.end())
+            return it->second;
+        return std::visit(
+            overloaded{[](const literal_node& n) -> std::optional<scalar> { return n.val; },
+                       [](const auto&) -> std::optional<scalar> { return std::nullopt; }},
+            get_node(e));
+    });

Copilot AI Apr 24, 2026

expr::eval calls get_node(e) via the replace lambda even when *this is empty; get_node asserts pimpl != nullptr, so evaluating an empty expression will crash. Please add an explicit if(empty()) behavior (e.g., throw MIGRAPHX_THROW with a clear message) or handle empty in the replace lambda before touching get_node.

Comment thread src/sym.cpp
Comment on lines +1442 to +1448
if((std::isdigit(p.peek_char()) == 0) and p.peek_char() != '.')
return {};
auto token = p.parse_while([](unsigned char c) { return std::isdigit(c) or c == '.'; });
bool is_float = token.find('.') != std::string_view::npos;
if(is_float)
return lit(std::stod(std::string(token)));
return lit(std::stoll(std::string(token)));
Copilot AI Apr 24, 2026

parse_number calls std::isdigit(p.peek_char()) where peek_char() is a char; if char is signed and the input contains non-ASCII bytes this is undefined behavior. Also, tokenization accepts any mix of digits and '.' (e.g. "." or "1..2"), which can make std::stod/std::stoll throw a standard exception rather than a MIGRAPHX_THROW with location context. Consider casting to unsigned char for ctype calls, validating the numeric token (single '.' / at least one digit), and converting parse failures into MIGRAPHX_THROW(p.error_message(...)).
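
The token validation suggested above (at most one '.', at least one digit) is small enough to sketch directly. `is_valid_number_token` is a hypothetical helper; in the parser it would presumably run before std::stod or std::stoll so that a malformed token can be reported with position context.

```cpp
#include <cctype>
#include <string_view>

// Accepts tokens made of digits with at most one '.', and at least one digit.
inline bool is_valid_number_token(std::string_view token)
{
    int dots   = 0;
    int digits = 0;
    for(char c : token)
    {
        if(c == '.')
            ++dots;
        else if(std::isdigit(static_cast<unsigned char>(c)) != 0)
            ++digits;
        else
            return false;
    }
    return digits > 0 and dots <= 1;
}
```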

Comment thread src/sym.cpp Outdated
Comment thread src/sym.cpp
Comment on lines +1094 to +1111
+interval expr::eval_interval(const std::unordered_map<expr, interval>& vars) const
 {
-    if(a.empty() or b.empty())
-        return {};
-    return {std::make_shared<expr::impl>(make_sub(a.p->node, b.p->node))};
+    return generic_eval<interval>(*this, [&](const expr& e) -> std::optional<interval> {
+        auto it = vars.find(e);
+        if(it != vars.end())
+            return it->second;
+        return std::visit(
+            overloaded{[](const literal_node& n) -> std::optional<interval> {
+                           return interval{n.val, n.val};
+                       },
+                       [](const variable_node& n) -> std::optional<interval> {
+                           if(not n.constraints.empty())
+                               return n.constraints.front();
+                           MIGRAPHX_THROW("Variable '" + n.name + "' not found in interval map");
+                       },
+                       [](const op_node&) -> std::optional<interval> { return std::nullopt; }},
+            get_node(e));
+    });

Copilot AI Apr 24, 2026

expr::eval_interval has the same empty-expression issue as eval: if *this is empty, the replace lambda falls through to get_node(e) and will assert/crash. Consider handling empty() up front (or in the replace lambda) with a consistent policy (throw or return a sentinel interval).

Comment thread src/sym.cpp
Comment on lines +1217 to +1225
+static std::string scalar_to_string(const scalar& v)
 {
-    if(a.empty() and b.empty())
-        return true;
-    if(a.empty() != b.empty())
-        return false;
-    return expr_equal(a.p->node, b.p->node);
+    return std::visit(
+        [](auto x) -> std::string {
+            std::ostringstream ss;
+            ss << x;
+            return ss.str();
+        },
+        v);

Copilot AI Apr 24, 2026

scalar_to_string uses default std::ostringstream formatting for doubles, which may emit scientific notation (e.g. 1e-06). However parse_number only accepts digits and '.' and will reject/throw on that output, breaking parse(to_string(expr)) for many valid doubles. Please either (a) extend parse_number to accept exponent forms, or (b) force a non-scientific, round-trippable formatting in scalar_to_string (and add coverage for such cases).
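
To make the round-trip concern concrete, here is a small sketch of option (a): print with enough significant digits to identify the double uniquely, and read it back with a parser that understands exponents. `format_double` and `parse_double` are hypothetical names, and std::strtod only stands in for an extended parse_number; neither is taken from the PR.

```cpp
#include <cstdlib>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>

// Print with max_digits10 significant digits, which is enough to recover the
// exact double on read-back. Small or large magnitudes come out in exponent form.
inline std::string format_double(double x)
{
    std::ostringstream ss;
    ss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
    return ss.str();
}

// Stand-in for a number parser that accepts exponent notation.
inline double parse_double(const std::string& s) { return std::strtod(s.c_str(), nullptr); }
```

A value such as 1e-6 is printed in exponent form by `format_double`, so a parser limited to digits and '.' would reject it, while an exponent-aware parser recovers the original value.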

Comment thread test/sym.cpp
Comment on lines +827 to +828
auto inner = lit(7) * lit(2);
EXPECT(inner.eval_uint({}) == 14);
Copilot AI Apr 24, 2026

eval_uint_symbol_map_partial currently doesn't use e or the symbol map at all; it just evaluates lit(7) * lit(2). This doesn't test partial evaluation/mapping behavior as the name/comment suggest. Consider changing it to actually call e.eval_uint({{x, 7}}) (or removing/renaming the test if partial mapping isn't supported).

Suggested change
-auto inner = lit(7) * lit(2);
-EXPECT(inner.eval_uint({}) == 14);
+EXPECT(e.eval_uint({{x, 7}}) == 14);
Comment thread test/sym.cpp Outdated
pfultz2 and others added 2 commits April 24, 2026 13:18
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
4 participants