diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 3db298930bd..a5491e8346e 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -43,6 +43,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_LAYERNORM_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); namespace { +// Minimum C*H*W for half_type to use channelwise kernel instead of MLIR +constexpr std::size_t channelwise_half_min_chw = 48 * 1024; template struct layernorm_base @@ -301,9 +303,18 @@ struct find_channelwise_convolution auto weights = ins->inputs().back(); auto num_spatial = ins->get_shape().ndim() - 2; - if(input->get_shape().type() != shape::float_type) + const auto type = input->get_shape().type(); + if(type != shape::float_type and type != shape::half_type) return; + if(type == shape::half_type and input->get_shape().ndim() == 4) + { + const auto& lens = input->get_shape().lens(); + const auto chw = lens[1] * lens[2] * lens[3]; + if(chw < channelwise_half_min_chw) + return; + } + auto v = ins->get_operator().to_value(); auto pad_vals = v.at("padding"); std::vector padding;