diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/channelwise_conv.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/channelwise_conv.hpp index be186ecf91e..cb1ddb445f5 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/channelwise_conv.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/channelwise_conv.hpp @@ -53,19 +53,19 @@ channelwise_conv(TileLens, Padding, F f, Output output, Input x, Weights w, Inpu auto xs_pack = pack(tiler.slice(inputs)...); using type = typename Output::type; - array wregs_arr; + array wregs_arr; auto wregs = make_tensor_view(wregs_arr.begin(), make_packed_shape(w_ch.get_shape())); copy(w_ch.begin(), w_ch.end(), wregs.begin()); __syncthreads(); tiler.for_each([&](auto out_pos, auto out_multi) { - type acc = 0; + float acc = 0; repeat(wregs.get_shape().elements(), [&](auto ki) { auto k_multi = wregs.get_shape().multi(ki); - acc += x_ch[out_multi + k_multi] * wregs[k_multi]; + acc += static_cast(x_ch[out_multi + k_multi]) * wregs[k_multi]; }); - xs_pack([&](auto... xs) { out_ch[out_pos] = f(acc, xs[out_pos]...); }); + xs_pack([&](auto... xs) { out_ch[out_pos] = f(static_cast(acc), xs[out_pos]...); }); }); }