Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/targets/gpu/jit/channelwise_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto output, auto x, auto w, auto... inputs) {
channelwise_conv<index_ints<${tile}>, ${ntiles}>(index_ints<${tile}>{}, index_ints<${padding}>{}, ${post}, output, x, w, inputs...);
channelwise_conv<index_ints<${tile}>, ${ntiles}>(index_ints<${tile}>{}, index_ints<${strides}>{}, index_ints<${padding}>{}, ${post}, output, x, w, inputs...);
});
}

Expand Down Expand Up @@ -113,9 +113,14 @@ struct channelwise_conv_compiler : compiler<channelwise_conv_compiler>
if(padding.size() < 2 * num_spatial)
padding.resize(2 * num_spatial, 0);

auto strides = v.get("strides", std::vector<std::size_t>(num_spatial, 1));
if(strides.size() < num_spatial)
strides.resize(num_spatial, 1);

auto src = interpolate_string(channelwise_conv_kernel,
{{"tile", to_string_range(tile_sizes)},
{"ntiles", std::to_string(noutputs)},
{"strides", to_string_range(strides)},
{"padding", to_string_range(padding)},
{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ namespace migraphx {

template <class TileLens,
index_int NTiles,
class Strides,
class Padding,
class F,
class Output,
class Input,
class Weights,
class... Inputs>
__device__ void
channelwise_conv(TileLens, Padding, F f, Output output, Input x, Weights w, Inputs... inputs)
__device__ void channelwise_conv(
TileLens, Strides, Padding, F f, Output output, Input x, Weights w, Inputs... inputs)
{
auto idx = make_index();
auto tiler = make_spatial_tiler<NTiles>(idx, TileLens{}, get_shape_c<Output>{}, Padding{});

auto idx = make_index();
auto tiler =
make_spatial_tiler<NTiles>(idx, TileLens{}, get_shape_c<Output>{}, Strides{}, Padding{});
__shared__ decltype(tiler.template shared_allocate<Input>()) smem;

auto x_ch = tiler.copy(x, smem);
Expand All @@ -60,10 +61,14 @@ channelwise_conv(TileLens, Padding, F f, Output output, Input x, Weights w, Inpu
__syncthreads();

tiler.for_each([&](auto out_pos, auto out_multi) {
auto halo_multi = out_multi;
constexpr auto cs = decltype(tiler)::conv_strides();
for(index_int d = 0; d < halo_multi.size(); d++)
halo_multi[d] *= cs[d];
type acc = 0;
repeat(wregs.get_shape().elements(), [&](auto ki) {
auto k_multi = wregs.get_shape().multi(ki);
acc += x_ch[out_multi + k_multi] * wregs[k_multi];
acc += x_ch[halo_multi + k_multi] * wregs[k_multi];
});
xs_pack([&](auto... xs) { out_ch[out_pos] = f(acc, xs[out_pos]...); });
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ constexpr bool has_nonzero(index_ints<Ps...>)
return ((Ps != 0) or ...);
}

template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<>>
template <index_int NTiles,
class TileLens,
class OutputShape,
class Padding = index_ints<>,
class Strides = index_ints<1, 1>>
struct spatial_tiler
{
static constexpr auto keep_spatial()
Expand Down Expand Up @@ -108,27 +112,37 @@ struct spatial_tiler
index idx;
array<index_int, ndim()> tile_origin;

// Per-dimension convolution strides, expanded to the full tensor rank:
// non-spatial (batch/channel) dims get stride 1, the trailing spatial dims
// take the user-supplied Strides, e.g. {1, 1, s_h, s_w} for NCHW.
// Build the full-rank stride array: every dimension defaults to stride 1,
// then the trailing Strides::size() entries are overwritten with the
// user-supplied spatial strides. Assumes Strides has no more entries than
// output_lens() has dimensions (the surrounding tiler guarantees this).
static constexpr auto conv_strides()
{
    return return_array_c([] {
        auto dims = output_lens();
        for(auto& d : dims)
            d = 1;
        constexpr auto spatial_strides = Strides{};
        constexpr auto nspatial        = spatial_strides.size();
        const auto offset              = dims.size() - nspatial;
        for(index_int i = 0; i < nspatial; i++)
            dims[offset + i] = spatial_strides[i];
        return dims;
    });
}

// Compute halo lens for a given input shape.
template <class InputShape>
static constexpr auto halo_lens_for()
{
constexpr auto halo_extra = [] {
constexpr auto halo_extra = return_array_c([] {
constexpr auto input_spatial = make_slice(InputShape{}, keep_spatial()).lens;
constexpr auto scaled_out =
transform(out_spatial_lens(), conv_strides(), [](auto o, auto s) { return o * s; });
if constexpr(has_conv_padding())
{
return return_array_c([] {
return make_slice(InputShape{}, keep_spatial()).lens - out_spatial_lens() +
total_padding();
});
}
return input_spatial - scaled_out + total_padding();
else
{
constexpr auto input_spatial = make_slice(InputShape{}, keep_spatial()).lens;
return transform(
input_spatial, out_spatial_lens(), [](auto is, auto os) { return is - os; });
}
}();
return transform(output_lens(), halo_extra, [](auto o, auto h) { return o + h; });
return input_spatial - scaled_out;
});
constexpr auto scaled_tile = transform(
output_lens(), conv_strides(), [](auto o, auto s) { return (o - 1) * s + 1; });
return transform(scaled_tile, halo_extra, [](auto t, auto h) { return t + h; });
}

// Type for shared memory allocation
Expand Down Expand Up @@ -164,9 +178,14 @@ struct spatial_tiler
auto input_ch = slice_tensor(
input, (channel_idx / index_int{groups}) % index_int{n_in}, keep_spatial());

auto strided_origin = tile_origin;
constexpr auto cs = conv_strides();
for(index_int d = 0; d < ndim(); d++)
strided_origin[d] *= cs[d];

idx.local_stride(_c<hl.product()>, [&](auto i) {
auto halo_multi = halo_shape.multi(i);
auto src_pos = tile_origin + halo_multi;
auto src_pos = strided_origin + halo_multi;
if constexpr(has_conv_padding())
{
constexpr auto pad = left_padding();
Expand Down Expand Up @@ -203,10 +222,14 @@ struct spatial_tiler
}
};

template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<>>
__device__ auto make_spatial_tiler(index idx, TileLens, OutputShape, Padding = {})
template <index_int NTiles,
class TileLens,
class OutputShape,
class Strides = index_ints<1, 1>,
class Padding = index_ints<>>
__device__ auto make_spatial_tiler(index idx, TileLens, OutputShape, Strides = {}, Padding = {})
{
using tiler_type = spatial_tiler<NTiles, TileLens, OutputShape, Padding>;
using tiler_type = spatial_tiler<NTiles, TileLens, OutputShape, Padding, Strides>;

constexpr auto block_shape = make_shape(return_array_c([] {
auto result = tiler_type::tiles_per_dim().base();
Expand Down
18 changes: 12 additions & 6 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,16 @@ struct channelwise_conv
{
std::size_t num_spatial = 2;
std::vector<std::size_t> padding;
std::vector<std::size_t> strides;

std::string name() const { return "gpu::channelwise_conv"; }

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.num_spatial, "num_spatial"), f(self.padding, "padding"));
return pack(f(self.num_spatial, "num_spatial"),
f(self.padding, "padding"),
f(self.strides, "strides"));
}

shape compute_shape(std::vector<shape> inputs) const
Expand All @@ -265,7 +268,8 @@ struct channelwise_conv
total_pad += padding[i];
if(i + num_spatial < padding.size())
total_pad += padding[i + num_spatial];
out_lens.push_back(x_lens[i + 2] + total_pad - w_lens[i + 2] + 1);
std::size_t s = (i < strides.size()) ? strides[i] : 1;
out_lens.push_back((x_lens[i + 2] + total_pad - w_lens[i + 2]) / s + 1);
}
return inputs[0].with_lens(out_lens);
}
Expand All @@ -277,8 +281,6 @@ MIGRAPHX_PRED_MATCHER(conv_channelwise, instruction_ref ins)
if(ins->name() != "convolution")
return false;
auto v = ins->get_operator().to_value();
if(not all_of(v.at("stride"), [](const value& x) { return x.to<std::size_t>() == 1; }))
return false;
if(not all_of(v.at("dilation"), [](const value& x) { return x.to<std::size_t>() == 1; }))
return false;
auto w_lens = ins->inputs().back()->get_shape().lens();
Expand Down Expand Up @@ -312,8 +314,12 @@ struct find_channelwise_convolution
std::back_inserter(padding),
[](const value& x) { return x.to<std::size_t>(); });

m.replace_instruction(
ins, channelwise_conv{num_spatial, std::move(padding)}, input, weights);
auto strides = v.at("stride").to_vector<std::size_t>();

m.replace_instruction(ins,
channelwise_conv{num_spatial, std::move(padding), std::move(strides)},
input,
weights);
}
};

Expand Down
Loading