Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/common.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -91,10 +91,8 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,

// Broadcast two shapes together in dynamic_dimension space.
// Both inputs are first aligned to the dynamic representation via
// shape::to_dynamic so static and dynamic shapes can be mixed; the
// element-wise broadcast is then delegated to the dyn_dims overload.
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
    auto aligned = shape::to_dynamic({s0, s1});
    return compute_broadcasted_dyn_dims(aligned[0].dyn_dims(), aligned[1].dyn_dims());
}

std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
Expand Down Expand Up @@ -235,19 +233,43 @@ instruction_ref add_common_op(module& m,
return insert_common_op(m, m.end(), op, std::move(inputs), options);
}

// Unified broadcast-shape builder: propagates input strides on matching axes and leaves
// `zero` on broadcast axes. Works for both static (Dim = std::size_t, Stride = std::size_t)
// and symbolic (Dim = shape::dynamic_dimension, Stride = sym::expr) representations.
// Dimensions are right-aligned: the input's trailing axes line up with the broadcast
// target's trailing axes (standard NumPy-style broadcasting).
template <class Dim, class Stride>
static shape make_bcast_shape_impl(shape::type_t type,
                                   const std::vector<Dim>& input_dims,
                                   const std::vector<Stride>& input_strides,
                                   const std::vector<Dim>& bcast_dims,
                                   const Stride& zero)
{
    assert(bcast_dims.size() >= input_dims.size());
    auto offset = bcast_dims.size() - input_dims.size();
    // Start with every stride zeroed (broadcast axis); overwrite matching axes below.
    std::vector<Stride> bcast_strides(bcast_dims.size(), zero);
    for(std::size_t i = 0; i < input_dims.size(); ++i)
    {
        // Matching dimension: keep the input's stride so data is shared, not copied.
        if(bcast_dims[i + offset] == input_dims[i])
            bcast_strides[i + offset] = input_strides[i];
    }
    return {type, bcast_dims, bcast_strides};
}

// Static-shape overload: build a broadcasted shape over `bcast_lens`, zeroing strides on
// broadcast axes and carrying the input's strides on matching axes. The input must not be
// dynamic (see the symbolic overload for the dynamic path).
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
{
    assert(not input_shape.dynamic());
    const std::size_t zero_stride = 0;
    return make_bcast_shape_impl(
        input_shape.type(), input_shape.lens(), input_shape.strides(), bcast_lens, zero_stride);
}

// Symbolic overload: input MUST already be symbolic (callers bridging from a static shape
// promote via shape::to_symbolic() first — see the header documentation). Broadcast axes
// receive sym::lit(0); matching axes propagate the input's symbolic stride.
shape make_bcast_shape(const shape& input_shape,
                       const std::vector<shape::dynamic_dimension>& bcast_dyn_dims)
{
    assert(input_shape.symbolic());
    return make_bcast_shape_impl(input_shape.type(),
                                 input_shape.dyn_dims(),
                                 input_shape.dyn_strides(),
                                 bcast_dyn_dims,
                                 sym::lit(0));
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
15 changes: 15 additions & 0 deletions src/include/migraphx/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,21 @@ instruction_ref add_common_op(module& m,
MIGRAPHX_EXPORT
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens);

/**
* Calculates the broadcasted shape from a symbolic broadcast target. The input shape MUST
* already be symbolic (`input_shape.symbolic()`); callers bridging from a static shape
* should promote via `shape::to_symbolic()` first. Mirrors the single-modality contract of
* the static overload above. Broadcast axes receive `sym::lit(0)`; matching axes propagate
* the input's symbolic stride.
*
* @param input_shape symbolic dynamic shape to broadcast
* @param bcast_dyn_dims symbolic dynamic dimensions to broadcast to
* @return broadcasted shape with symbolic dyn_strides
*/
MIGRAPHX_EXPORT
shape make_bcast_shape(const shape& input_shape,
const std::vector<shape::dynamic_dimension>& bcast_dyn_dims);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
1 change: 1 addition & 0 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
}
MIGRAPHX_PRED_MATCHER(dynamic_shape, instruction_ref ins) { return ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(static_shape, instruction_ref ins) { return not ins->get_shape().dynamic(); }
// Matches instructions whose output shape reports symbolic() == true.
MIGRAPHX_PRED_MATCHER(symbolic_shape, instruction_ref ins) { return ins->get_shape().symbolic(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
Expand Down
10 changes: 8 additions & 2 deletions src/include/migraphx/op/binary.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -70,7 +70,8 @@ struct binary : op_name<Derived>
.same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0.dynamic() or s1.dynamic())
// Range-based dynamic (or mixed dynamic/static) inputs only support strict equality.
if((s0.dynamic() or s1.dynamic()) and not(s0.symbolic() and s1.symbolic()))
{
if(s0 == s1)
return s0;
Expand All @@ -86,10 +87,15 @@ struct binary : op_name<Derived>
}
else if(s0.broadcasted() != s1.broadcasted())
{
if(s0.symbolic())
return s0.broadcasted() ? s1.with_lens(s0.dyn_dims()) : s0.with_lens(s0.dyn_dims());
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
if(s0.symbolic())
return shape::from_permutation(
s0.type(), s0.dyn_dims(), find_permutation({s0, s1}));
return shape::from_permutation(s0.type(), s0.lens(), find_permutation({s0, s1}));
}
}
Expand Down
185 changes: 113 additions & 72 deletions src/include/migraphx/op/broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,103 +45,144 @@ namespace op {
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
*
* Symbolic 1 input version: opt-in via a fully-symbolic `output_dyn_dims` attribute. Input may be
* static (promoted via shape::to_symbolic()) or already symbolic. Range-based dynamic input is
* not allowed (per the "no mixing symbolic and range-based" design rule). `broadcast_lens` is
* not used in this mode.
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast
{
// Axis at which the input's dimensions are placed in the broadcast target.
uint64_t axis = 0;
// Static broadcast target lengths (1-input static mode; unused in the 2-input
// and symbolic modes).
std::vector<std::size_t> broadcast_lens = {};
// Fully-symbolic broadcast target (opt-in symbolic 1-input mode); empty otherwise.
std::vector<shape::dynamic_dimension> output_dyn_dims = {};

// Expose the operator's attributes for serialization/reflection.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
    return pack(f(self.axis, "axis"),
                f(self.broadcast_lens, "out_lens"),
                f(self.output_dyn_dims, "out_dyn_dims"));
}

// Operator identifier for this op.
std::string name() const { return "broadcast"; }
// Validate axis/dims and place `in_strides` at `axis`, filling broadcast positions with
// `zero`. Templated for the static (size_t) and symbolic (sym::expr) paths.
// Throws when `axis` is out of range, when the target cannot hold the input dims at
// `axis`, or when any succeeding dimension does not match the target.
template <class Target, class Dims, class Strides, class Zero>
shape build_output(shape::type_t t,
                   const Target& target,
                   const Dims& in_dims,
                   const Strides& in_strides,
                   const Zero& zero) const
{
    if(axis >= target.size())
        MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) + " is out of range");
    if(target.size() - axis < in_dims.size())
        MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
    for(std::size_t i = 0; i < in_dims.size(); ++i)
    {
        if(target[axis + i] != in_dims[i])
            MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
    }
    // Zero (broadcast) strides everywhere except the copied input span at `axis`.
    std::vector<Zero> bcast_strides(target.size(), zero);
    std::copy(in_strides.begin(), in_strides.end(), bcast_strides.begin() + axis);
    return shape{t, target, std::move(bcast_strides)};
}

// Single-input shape computation.
// Symbolic mode is opted into via a fully-symbolic `output_dyn_dims` attribute; the input
// is promoted with shape::to_symbolic(). Range-based dynamic input (without symbolic
// target) is rejected. Otherwise the static `broadcast_lens` path is used.
shape compute_shape_1in(shape s0) const
{
    // the ONNX broadcast op is deprecated now, so not handling the negative
    // value of axis anymore
    const bool symbolic_target = not output_dyn_dims.empty() and
                                 std::all_of(output_dyn_dims.begin(),
                                             output_dyn_dims.end(),
                                             [](const auto& d) { return d.is_symbolic(); });
    if(not output_dyn_dims.empty() and not symbolic_target)
        MIGRAPHX_THROW("BROADCAST: output_dyn_dims must be fully symbolic");

    // Range-based dynamic input is only allowed when both the target and the input are
    // symbolic (the "no mixing symbolic and range-based" design rule).
    if(s0.dynamic() and not(symbolic_target and s0.symbolic()))
        MIGRAPHX_THROW("BROADCAST: Single dynamic input shape not supported. Use two inputs.");

    if(symbolic_target)
    {
        auto s0_sym = s0.to_symbolic();
        return build_output(
            s0.type(), output_dyn_dims, s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0));
    }

    auto output =
        build_output(s0.type(), broadcast_lens, s0.lens(), s0.strides(), std::size_t{0});
    if(output.elements() < s0.elements())
    {
        // don't think this can occur?
        MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
    }
    return output;
}

// Two-input shape computation: broadcast the 1-D first input into the second input's
// shape at `axis`. Handles static, range-based dynamic s1, and symbolic inputs
// (symbolic path promotes both via shape::to_symbolic()).
shape compute_shape_2in(shape s0, shape s1) const
{
    if(s0.ndim() != 1)
    {
        MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
                       ", only handle ndim = 1");
    }

    if(s0.symbolic() or s1.symbolic())
    {
        auto s0_sym = s0.to_symbolic();
        auto s1_sym = s1.to_symbolic();
        // build_output performs the axis-range and dimension-match validation.
        return build_output(
            s0.type(), s1_sym.dyn_dims(), s0_sym.dyn_dims(), s0_sym.dyn_strides(), sym::lit(0));
    }

    if(axis >= s1.ndim())
    {
        MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) + " is out of range");
    }

    // Range-based dynamic s0 alone is not supported.
    if(s0.dynamic())
    {
        MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
                       "a dynamic shape");
    }
    if(s1.dynamic())
    {
        s0 = s0.to_dynamic();
        if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
        {
            MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
                           "dimension length (" +
                           migraphx::to_string(s0.dyn_dims()[0]) +
                           " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
        }
        return s1;
    }

    if(s0.lens()[0] != s1.lens()[axis])
    {
        MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
                       "dimension length (" +
                       migraphx::to_string(s0.lens()[0]) +
                       " != " + migraphx::to_string(s1.lens()[axis]) + ")");
    }
    std::vector<size_t> bcast_strides(s1.ndim(), 0);
    std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
    return shape{s0.type(), s1.lens(), std::move(bcast_strides)};
}

// Dispatch to the one-input or two-input shape computation after validating arity.
shape compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this, true}.has(1, 2);
    if(inputs.size() == 2)
        return compute_shape_2in(inputs.at(0), inputs.at(1));
    return compute_shape_1in(inputs.at(0));
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
Loading
Loading