Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_LAYERNORM_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);

namespace {
// Lower bound on C*H*W (product of input dimensions 1, 2 and 3) that a 4-D
// half_type input must reach for the channelwise kernel to be used; smaller
// inputs are left to MLIR.
constexpr std::size_t channelwise_half_min_chw{48 * 1024};

template <class Derived, std::size_t N>
struct layernorm_base
Expand Down Expand Up @@ -301,9 +303,18 @@ struct find_channelwise_convolution
auto weights = ins->inputs().back();
auto num_spatial = ins->get_shape().ndim() - 2;

if(input->get_shape().type() != shape::float_type)
const auto type = input->get_shape().type();
if(type != shape::float_type and type != shape::half_type)
return;

if(type == shape::half_type and input->get_shape().ndim() == 4)
{
const auto& lens = input->get_shape().lens();
const auto chw = lens[1] * lens[2] * lens[3];
if(chw < channelwise_half_min_chw)
return;
}

auto v = ins->get_operator().to_value();
auto pad_vals = v.at("padding");
std::vector<std::size_t> padding;
Expand Down
Loading