diff --git a/test/test_ops.py b/test/test_ops.py
index 9521f21a815..106a993acd6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -643,6 +643,40 @@ def test_performance_mps(self):
                 execution_time_ms < execution_time_ms_threshold
             ), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms"
 
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_roi_align_large_index(self, device):
+        """Regression test for https://github.com/pytorch/vision/issues/8206"""
+        pooled_h, pooled_w = 7, 7
+        channels = 4
+        # 11M * 4 * 7 * 7 = 2,156,000,000 > INT_MAX
+        n_rois = 11_000_000
+        num_imgs = 2
+        height, width = 4, 4
+        spatial_scale = 1.0
+        sampling_ratio = 2
+
+        x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device, requires_grad=True)
+        rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device)
+
+        rois[:, 0] = torch.randint(0, num_imgs, (n_rois,))
+        rois[:, 1] = 0
+        rois[:, 2] = 0
+        rois[:, 3] = width - 1
+        rois[:, 4] = height - 1
+
+        # Call the C++ kernel directly, in case torchvision.ops.roi_align falls
+        # back to a pure-Python path that doesn't have the int32 overflow bug.
+        result = torch.ops.torchvision.roi_align(x, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, False)
+
+        # Forward kernel test
+        assert result.shape == (n_rois, channels, pooled_h, pooled_w)
+        assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug"
+
+        # Backward kernel test
+        result.sum().backward()
+        assert x.grad is not None, "x.grad is None — backward was not executed"
+        assert x.grad.abs().sum() > 0, "x.grad is all zeros — likely an index overflow bug in the backward kernel"
+
 
 class TestPSRoIAlign(RoIOpTester):
     mps_backward_atol = 5e-2
diff --git a/torchvision/csrc/ops/cpu/roi_align_common.h b/torchvision/csrc/ops/cpu/roi_align_common.h
index e10c67b5b79..cb5c0deb658 100644
--- a/torchvision/csrc/ops/cpu/roi_align_common.h
+++ b/torchvision/csrc/ops/cpu/roi_align_common.h
@@ -8,10 +8,10 @@ namespace detail {
 
 template <typename T>
 struct PreCalc {
-  int pos1;
-  int pos2;
-  int pos3;
-  int pos4;
+  int64_t pos1;
+  int64_t pos2;
+  int64_t pos3;
+  int64_t pos4;
   T w1;
   T w2;
   T w3;
@@ -42,7 +42,7 @@ void pre_calc_for_bilinear_interpolate(
     int roi_bin_grid_h,
     int roi_bin_grid_w,
     std::vector<PreCalc<T>>& pre_calc) {
-  int pre_calc_index = 0;
+  int64_t pre_calc_index = 0;
   for (int ph = 0; ph < pooled_height; ph++) {
     for (int pw = 0; pw < pooled_width; pw++) {
       for (int iy = 0; iy < roi_bin_grid_h; iy++) {
@@ -106,10 +106,10 @@ void pre_calc_for_bilinear_interpolate(
 
           // save weights and indices
           PreCalc<T> pc;
-          pc.pos1 = y_low * width + x_low;
-          pc.pos2 = y_low * width + x_high;
-          pc.pos3 = y_high * width + x_low;
-          pc.pos4 = y_high * width + x_high;
+          pc.pos1 = static_cast<int64_t>(y_low) * width + x_low;
+          pc.pos2 = static_cast<int64_t>(y_low) * width + x_high;
+          pc.pos3 = static_cast<int64_t>(y_high) * width + x_low;
+          pc.pos4 = static_cast<int64_t>(y_high) * width + x_high;
           pc.w1 = w1;
           pc.w2 = w2;
           pc.w3 = w3;
diff --git a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
index e0185da45df..39f670d8112 100644
--- a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
@@ -26,7 +26,8 @@ void roi_align_forward_kernel_impl(
   // can be parallelized using omp
   // #pragma omp parallel for num_threads(32)
   for (int n = 0; n < n_rois; n++) {
-    int index_n = n * channels * pooled_width * pooled_height;
+    int64_t index_n =
+        static_cast<int64_t>(n) * channels * pooled_width * pooled_height;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
@@ -78,14 +79,15 @@ void roi_align_forward_kernel_impl(
         pre_calc);
 
     for (int c = 0; c < channels; c++) {
-      int index_n_c = index_n + c * pooled_width * pooled_height;
-      const T* offset_input =
-          input + (roi_batch_ind * channels + c) * height * width;
+      int64_t index_n_c =
+          index_n + static_cast<int64_t>(c) * pooled_width * pooled_height;
+      const T* offset_input = input +
+          (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
       int pre_calc_index = 0;
 
       for (int ph = 0; ph < pooled_height; ph++) {
         for (int pw = 0; pw < pooled_width; pw++) {
-          int index = index_n_c + ph * pooled_width + pw;
+          int64_t index = index_n_c + ph * pooled_width + pw;
 
           T output_val = 0.;
           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
@@ -175,7 +177,7 @@ inline void add(T* address, const T& val) {
 
 template <typename T>
 void roi_align_backward_kernel_impl(
-    int nthreads,
+    int64_t nthreads,
     const T* grad_output,
     const T& spatial_scale,
     int channels,
@@ -187,11 +189,11 @@ void roi_align_backward_kernel_impl(
     bool aligned,
     T* grad_input,
     const T* rois,
-    int n_stride,
-    int c_stride,
-    int h_stride,
-    int w_stride) {
-  for (int index = 0; index < nthreads; index++) {
+    int64_t n_stride,
+    int64_t c_stride,
+    int64_t h_stride,
+    int64_t w_stride) {
+  for (int64_t index = 0; index < nthreads; index++) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
@@ -219,10 +221,10 @@ void roi_align_backward_kernel_impl(
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-    T* offset_grad_input =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
+    T* offset_grad_input = grad_input +
+        ((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width);
 
-    int output_offset = n * n_stride + c * c_stride;
+    int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
@@ -359,10 +361,10 @@ at::Tensor roi_align_backward_kernel(
   }
 
   // get stride values to ensure indexing into gradients is correct.
-  int n_stride = grad.stride(0);
-  int c_stride = grad.stride(1);
-  int h_stride = grad.stride(2);
-  int w_stride = grad.stride(3);
+  int64_t n_stride = grad.stride(0);
+  int64_t c_stride = grad.stride(1);
+  int64_t h_stride = grad.stride(2);
+  int64_t w_stride = grad.stride(3);
 
   auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
diff --git a/torchvision/csrc/ops/cuda/roi_align_kernel.cu b/torchvision/csrc/ops/cuda/roi_align_kernel.cu
index 26c53448663..e9fbf4060f2 100644
--- a/torchvision/csrc/ops/cuda/roi_align_kernel.cu
+++ b/torchvision/csrc/ops/cuda/roi_align_kernel.cu
@@ -67,7 +67,7 @@ __device__ T bilinear_interpolate(
 
 template <typename T>
 __global__ void roi_align_forward_kernel_impl(
-    int nthreads,
+    int64_t nthreads,
     const T* input,
     const T spatial_scale,
     int channels,
@@ -79,7 +79,7 @@ __global__ void roi_align_forward_kernel_impl(
     bool aligned,
     const T* rois,
     T* output) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+  CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
@@ -107,8 +107,8 @@ __global__ void roi_align_forward_kernel_impl(
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
-    const T* offset_input =
-        input + (roi_batch_ind * channels + c) * height * width;
+    const T* offset_input = input +
+        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
@@ -203,7 +203,7 @@ __device__ void bilinear_interpolate_gradient(
 
 template <typename T>
 __global__ void roi_align_backward_kernel_impl(
-    int nthreads,
+    int64_t nthreads,
    const T* grad_output,
     const T spatial_scale,
     int channels,
@@ -215,12 +215,12 @@ __global__ void roi_align_backward_kernel_impl(
     bool aligned,
     T* grad_input,
     const T* rois,
-    int n_stride,
-    int c_stride,
-    int h_stride,
-    int w_stride,
-    const int memory_span) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    int64_t n_stride,
+    int64_t c_stride,
+    int64_t h_stride,
+    int64_t w_stride,
+    const int64_t memory_span) {
+  CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
@@ -250,7 +250,8 @@ __global__ void roi_align_backward_kernel_impl(
 
     // We need to index the gradient using the tensor strides to access the
     // correct values.
-    const int output_offset = n * n_stride + c * c_stride;
+    const int64_t output_offset =
+        static_cast<int64_t>(n) * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
@@ -265,7 +266,8 @@ __global__ void roi_align_backward_kernel_impl(
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
 
-    const int input_offset = (roi_batch_ind * channels + c) * height * width;
+    const int64_t input_offset =
+        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
@@ -432,10 +434,10 @@ at::Tensor roi_align_backward_kernel(
     return grad_input;
   }
 
-  int n_stride = grad.stride(0);
-  int c_stride = grad.stride(1);
-  int h_stride = grad.stride(2);
-  int w_stride = grad.stride(3);
+  int64_t n_stride = grad.stride(0);
+  int64_t c_stride = grad.stride(1);
+  int64_t h_stride = grad.stride(2);
+  int64_t w_stride = grad.stride(3);
 
   at::globalContext().alertNotDeterministic("roi_align_backward_kernel");