Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchvision/csrc/ops/cuda/roi_pool_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <float.h>
Expand Down Expand Up
@@ -63,10 +64,14 @@ __global__ void roi_pool_forward_kernel_impl(
 int maxidx = -1;
 const T* offset_input =
 input + (roi_batch_ind * channels + c) * height * width;
+using acc_t = at::acc_type<T, /*is_cuda=*/true>;
 for (int h = hstart; h < hend; ++h) {
 for (int w = wstart; w < wend; ++w) {
 int input_index = h * width + w;
-if (offset_input[input_index] > maxval) {
+acc_t v = static_cast<acc_t>(offset_input[input_index]);
+acc_t mv = static_cast<acc_t>(maxval);
+
+if (v > mv) {
 maxval = offset_input[input_index];
 maxidx = input_index;
 }
Expand Down
Loading