Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions src/common.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -235,19 +235,43 @@
return insert_common_op(m, m.end(), op, std::move(inputs), options);
}

shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
// Unified broadcast-shape builder: propagates input strides on matching axes and leaves
// `zero` on broadcast axes. Works for both static (Dim = std::size_t, Stride = std::size_t)
// and symbolic (Dim = shape::dynamic_dimension, Stride = sym::expr) representations.
template <class Dim, class Stride>
static shape make_bcast_shape_impl(shape::type_t type,
const std::vector<Dim>& input_dims,
const std::vector<Stride>& input_strides,
const std::vector<Dim>& bcast_dims,
Stride zero)

Check warning on line 246 in src/common.cpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'zero' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
{
assert(not input_shape.dynamic());
auto offset = bcast_lens.size() - input_shape.ndim();
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
for(std::ptrdiff_t i : reverse(range(input_shape.ndim())))
assert(bcast_dims.size() >= input_dims.size());
auto offset = bcast_dims.size() - input_dims.size();
std::vector<Stride> bcast_strides(bcast_dims.size(), zero);
for(std::size_t i = 0; i < input_dims.size(); ++i)
{
if(bcast_lens.at(i + offset) == input_shape.lens()[i])
{
bcast_strides.at(i + offset) = input_shape.strides()[i];
}
if(bcast_dims[i + offset] == input_dims[i])
bcast_strides[i + offset] = input_strides[i];
}
Comment thread
shivadbhavsar marked this conversation as resolved.
return shape{input_shape.type(), bcast_lens, bcast_strides};
return {type, bcast_dims, bcast_strides};
}

shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
{
assert(not input_shape.dynamic());
return make_bcast_shape_impl(
input_shape.type(), input_shape.lens(), input_shape.strides(), bcast_lens, std::size_t{0});
}

shape make_bcast_shape(const shape& input_shape,
const std::vector<shape::dynamic_dimension>& bcast_dyn_dims)
{
assert(input_shape.symbolic());
return make_bcast_shape_impl(input_shape.type(),
input_shape.dyn_dims(),
input_shape.dyn_strides(),
bcast_dyn_dims,
sym::lit(0));
Comment thread
shivadbhavsar marked this conversation as resolved.
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
15 changes: 15 additions & 0 deletions src/include/migraphx/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,21 @@ instruction_ref add_common_op(module& m,
MIGRAPHX_EXPORT
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens);

/**
* Calculates the broadcasted shape from a symbolic broadcast target. The input shape MUST
* already be symbolic (`input_shape.symbolic()`); callers bridging from a static shape
* should promote via `shape::to_symbolic()` first. Mirrors the single-modality contract of
* the static overload above. Broadcast axes receive `sym::lit(0)`; matching axes propagate
* the input's symbolic stride.
*
* @param input_shape symbolic dynamic shape to broadcast
* @param bcast_dyn_dims symbolic dynamic dimensions to broadcast to
* @return broadcasted shape with symbolic dyn_strides
*/
MIGRAPHX_EXPORT
shape make_bcast_shape(const shape& input_shape,
const std::vector<shape::dynamic_dimension>& bcast_dyn_dims);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
1 change: 1 addition & 0 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
}
MIGRAPHX_PRED_MATCHER(dynamic_shape, instruction_ref ins) { return ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(static_shape, instruction_ref ins) { return not ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(symbolic_shape, instruction_ref ins) { return ins->get_shape().symbolic(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
Expand Down
10 changes: 8 additions & 2 deletions src/include/migraphx/op/binary.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -70,7 +70,8 @@ struct binary : op_name<Derived>
.same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0.dynamic() or s1.dynamic())
// Range-based dynamic (or mixed dynamic/static) inputs only support strict equality.
if((s0.dynamic() or s1.dynamic()) and not(s0.symbolic() and s1.symbolic()))
{
if(s0 == s1)
return s0;
Expand All @@ -86,10 +87,15 @@ struct binary : op_name<Derived>
}
else if(s0.broadcasted() != s1.broadcasted())
{
if(s0.symbolic())
return s0.broadcasted() ? s1.with_lens(s0.dyn_dims()) : s0.with_lens(s0.dyn_dims());
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
if(s0.symbolic())
return shape::from_permutation(
s0.type(), s0.dyn_dims(), find_permutation({s0, s1}));
return shape::from_permutation(s0.type(), s0.lens(), find_permutation({s0, s1}));
}
}
Expand Down
82 changes: 59 additions & 23 deletions src/include/migraphx/op/broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,28 @@
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
*
* Symbolic 1 input version: opt-in via a fully-symbolic `output_dyn_dims` attribute. Input may be
* static (promoted via shape::to_symbolic()) or already symbolic. Range-based dynamic input is
* not allowed (per the "no mixing symbolic and range-based" design rule). `broadcast_lens` is
* not used in this mode.
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens = {};
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens = {};
std::vector<shape::dynamic_dimension> output_dyn_dims = {};

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
return pack(f(self.axis, "axis"),
f(self.broadcast_lens, "out_lens"),
f(self.output_dyn_dims, "out_dyn_dims"));
}

std::string name() const { return "broadcast"; }
Expand All @@ -67,30 +75,47 @@
check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0);
auto t = s0.type();

// Validate axis/dims and place `in_strides` at `axis`, filling broadcast positions with
// `zero`. Templated for the static (size_t) and symbolic (sym::expr) paths.
auto build_output =
[&](const auto& target, const auto& in_dims, const auto& in_strides, auto zero) {

Check warning on line 82 in src/include/migraphx/op/broadcast.hpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'zero' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 82 in src/include/migraphx/op/broadcast.hpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'zero' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 82 in src/include/migraphx/op/broadcast.hpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'zero' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
if(axis >= target.size())
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
if(target.size() - axis < in_dims.size())
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
for(std::size_t i = 0; i < in_dims.size(); ++i)
{
if(target[axis + i] != in_dims[i])
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::vector<decltype(zero)> bcast_strides(target.size(), zero);
std::copy(in_strides.begin(), in_strides.end(), bcast_strides.begin() + axis);
return shape{t, target, std::move(bcast_strides)};
};

if(inputs.size() == 1)
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(s0.dynamic())
const bool symbolic_target = not output_dyn_dims.empty() and
std::all_of(output_dyn_dims.begin(),
output_dyn_dims.end(),
[](const auto& d) { return d.is_symbolic(); });

if(s0.dynamic() and not(symbolic_target and s0.symbolic()))
MIGRAPHX_THROW(
"BROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))

Comment thread
shivadbhavsar marked this conversation as resolved.
Outdated
if(symbolic_target)
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
auto s0_sym = s0.to_symbolic();
return build_output(
output_dyn_dims, s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0));
}

std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
auto output = build_output(broadcast_lens, s0.lens(), s0.strides(), std::size_t{0});
if(output.elements() < s0.elements())
{
// don't think this can occur?
Expand All @@ -102,21 +127,32 @@
{
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic())
{
MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape");
}
if(s0.ndim() != 1)
{
MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
", only handle ndim = 1");
}

if(s0.symbolic() or s1.symbolic())
{
auto s0_sym = s0.to_symbolic();
auto s1_sym = s1.to_symbolic();
return build_output(
s1_sym.dyn_dims(), s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0));
}

if(axis >= s1.ndim())
{
MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
}

// Range-based dynamic s0 alone is not supported.
if(s0.dynamic())
{
MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape");
}
if(s1.dynamic())
{
s0 = s0.to_dynamic();
Expand Down
46 changes: 31 additions & 15 deletions src/include/migraphx/op/multibroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace op {
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: 1 input and 2+ inputs.
* One input version uses output_lens attribute and broadcasts to it.
* One input version uses output_lens (static target) or output_dyn_dims (symbolic target);
* see compute_shape for the symbolic single-input contract.
* 2+ inputs version broadcasts first input to the common shape at evaluation time.
*/
struct multibroadcast
Expand All @@ -59,7 +60,7 @@ struct multibroadcast
{
check_shapes{inputs, *this, true}.has_at_least(1);

auto t = inputs.at(0).type();
auto t = inputs.at(0).type();
const auto& s0 = inputs.at(0);

if(s0.ndim() < 1)
Expand All @@ -69,25 +70,40 @@ struct multibroadcast

if(inputs.size() == 1)
{
if(s0.dynamic())
// Symbolic 1-input mode: opt-in via a fully-symbolic output_dyn_dims attribute.
// Input may be static (bridged via to_symbolic()) or already symbolic.
// Range-based dynamic input is not allowed.
const bool symbolic_target = not output_dyn_dims.empty() and
std::all_of(output_dyn_dims.begin(),
output_dyn_dims.end(),
[](const auto& d) { return d.is_symbolic(); });

if(s0.dynamic() and not(symbolic_target and s0.symbolic()))
MIGRAPHX_THROW(
"MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(s0.ndim() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
}

Comment thread
shivadbhavsar marked this conversation as resolved.
auto offset = output_lens.size() - s0.ndim();
for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
// Shared validation: input dims must align with target dims, with axis-1 broadcast.
auto validate = [](const auto& in_dims, const auto& out_dims) {
if(in_dims.size() > out_dims.size())
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
auto offset = out_dims.size() - in_dims.size();
for(std::ptrdiff_t i = in_dims.size() - 1; i >= 0; --i)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
if(out_dims[i + offset] != in_dims[i] and in_dims[i] != 1)
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(in_dims) +
"} cannot be broadcasted to {" + to_string_range(out_dims) +
"}!");
}
};

if(symbolic_target)
{
auto s0_sym = s0.to_symbolic();
validate(s0_sym.dyn_dims(), output_dyn_dims);
return make_bcast_shape(s0_sym, output_dyn_dims);
}

validate(s0.lens(), output_lens);
return make_bcast_shape(s0, output_lens);
}
else
Expand All @@ -105,7 +121,7 @@ struct multibroadcast
else
{
// output_lens will not be set for 2+ input version
auto bcast_lens = compute_common_lens(inputs);
auto bcast_lens = compute_common_lens(inputs);
return make_bcast_shape(s0, bcast_lens);
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,11 @@ struct MIGRAPHX_EXPORT shape
// convert the shape to an equivalent dynamic shape with constant symbolic strides
shape to_dynamic() const;

// convert the shape to an equivalent symbolic dynamic shape: each static len becomes
// dd{sym::lit(len)} and each static stride becomes sym::lit(stride). Idempotent on a
// shape that is already symbolic. Throws on a range-based dynamic shape.
shape to_symbolic() const;

// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
shape to_static(const std::unordered_map<sym::expr, std::size_t>& symbol_map) const;
Expand Down
34 changes: 34 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,40 @@ shape shape::to_dynamic() const
return {type(), std::move(dims), std::move(dstrides)};
}

shape shape::to_symbolic() const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[](auto s) { return s.to_symbolic(); });
return shape(subs);
}
if(this->symbolic())
{
return *this;
}
if(this->dynamic())
{
// Range-based dynamic shapes have no clean symbolic representation: range info
// would have to be silently dropped. Disallowed by the "no mixing" design rule.
MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape");
}
std::vector<dynamic_dimension> dims;
dims.reserve(ndim());
std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) {
return dynamic_dimension{sym::lit(len)};
});
std::vector<sym::expr> dstrides;
dstrides.reserve(ndim());
std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) {
return sym::lit(s);
});
return {type(), std::move(dims), std::move(dstrides)};
}

shape shape::to_static(std::size_t x) const
{
if(not sub_shapes().empty())
Expand Down
Loading
Loading