Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,42 @@ def test_performance_mps(self):
execution_time_ms < execution_time_ms_threshold
), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms"

@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_roi_align_large_index(self, device):
"""Regression test for https://github.com/pytorch/vision/issues/8206"""
pooled_h, pooled_w = 7, 7
channels = 4
# 11M * 4 * 7 * 7 = 2,156,000,000 > INT_MAX
n_rois = 11_000_000
num_imgs = 2
height, width = 4, 4
spatial_scale = 1.0
sampling_ratio = 2

output_bytes = n_rois * channels * pooled_h * pooled_w * 4 # float32
if output_bytes > 9 * 1024**3:
pytest.skip("Test requires ~9 GB of memory")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all these values are statically defined. This if block is either always True or always False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. I agree. I removed that part in the new commit.

try:
x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device)
rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device)
except RuntimeError:
pytest.skip("Not enough memory to allocate test tensors")
Copy link
Member

@NicolasHug NicolasHug Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please verify tests aren't being skipped on the CI. If they pass, remove the try/except; if they don't, we'll have to consider other strategies to test this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. I verified that the tests are not being skipped on the CI and my devserver. I removed the try/except in the new commit.


rois[:, 0] = torch.randint(0, num_imgs, (n_rois,))
rois[:, 1] = 0
rois[:, 2] = 0
rois[:, 3] = width - 1
rois[:, 4] = height - 1

try:
result = torch.ops.torchvision.roi_align(x, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, False)
except RuntimeError:
pytest.skip("Not enough memory for roi_align output")

assert result.shape == (n_rois, channels, pooled_h, pooled_w)
assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug"


class TestPSRoIAlign(RoIOpTester):
mps_backward_atol = 5e-2
Expand Down
18 changes: 9 additions & 9 deletions torchvision/csrc/ops/cpu/roi_align_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace detail {

template <typename T>
struct PreCalc {
int pos1;
int pos2;
int pos3;
int pos4;
int64_t pos1;
int64_t pos2;
int64_t pos3;
int64_t pos4;
T w1;
T w2;
T w3;
Expand Down Expand Up @@ -42,7 +42,7 @@ void pre_calc_for_bilinear_interpolate(
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
int pre_calc_index = 0;
int64_t pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
Expand Down Expand Up @@ -106,10 +106,10 @@ void pre_calc_for_bilinear_interpolate(

// save weights and indices
PreCalc<T> pc;
pc.pos1 = y_low * width + x_low;
pc.pos2 = y_low * width + x_high;
pc.pos3 = y_high * width + x_low;
pc.pos4 = y_high * width + x_high;
pc.pos1 = static_cast<int64_t>(y_low) * width + x_low;
pc.pos2 = static_cast<int64_t>(y_low) * width + x_high;
pc.pos3 = static_cast<int64_t>(y_high) * width + x_low;
pc.pos4 = static_cast<int64_t>(y_high) * width + x_high;
pc.w1 = w1;
pc.w2 = w2;
pc.w3 = w3;
Expand Down
38 changes: 20 additions & 18 deletions torchvision/csrc/ops/cpu/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void roi_align_forward_kernel_impl(
// can be parallelized using omp
// #pragma omp parallel for num_threads(32)
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
int64_t index_n =
static_cast<int64_t>(n) * channels * pooled_width * pooled_height;

const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
Expand Down Expand Up @@ -78,14 +79,15 @@ void roi_align_forward_kernel_impl(
pre_calc);

for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
int64_t index_n_c =
index_n + static_cast<int64_t>(c) * pooled_width * pooled_height;
const T* offset_input = input +
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
int pre_calc_index = 0;

for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
int64_t index = index_n_c + ph * pooled_width + pw;

T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
Expand Down Expand Up @@ -175,7 +177,7 @@ inline void add(T* address, const T& val) {

template <typename T>
void roi_align_backward_kernel_impl(
int nthreads,
int64_t nthreads,
Copy link
Member

@NicolasHug NicolasHug Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why nthreads needs to be int64_t? It should never need to be that large? If it's for integer comparison to not warn, we could just cast?

Copy link
Contributor Author

@zy1git zy1git Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the question. nthreads in the backward kernel is grad.numel(), which equals n_rois × channels × pooled_h × pooled_w.

I changed to int nthreads in both .cpp and .cu files and ran the test (only the forward kernel test, no backward kernel test due to the large memory requirement). The CPU test passed but the CUDA test failed with all-zero output. The reason is that the CPU forward kernel doesn't use nthreads — it loops over n_rois separately, and the overflow is handled by the int64_t changes to index_n, index_n_c, and index. The CUDA forward kernel uses a flat loop with nthreads as the bound, so truncating to int caused nthreads to wrap to a negative value, making the loop condition immediately false and skipping all output computation — resulting in all-zero output.

The CPU backward kernel does use nthreads in the same flat-loop pattern as CUDA (for (int64_t index = 0; index < nthreads; ...)) and receives the same overflowing value via grad.numel(), so it needs int64_t for the same reason.

A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?

Copy link
Contributor Author

@zy1git zy1git Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value could be actually large and gets truncated at the call site before the function body runs. In the author's reproducing example (batch_size=172, default 1000 proposals per image), nthreads is 172,000 × 256 × 7 × 7 = 2,157,568,000 > INT_MAX.

I added the backward kernel test in the latest commit. If I change `int64_t nthreads` to `int nthreads`, the CPU backward test fails with all-zero gradients because nthreads gets truncated to a negative value and the loop never executes.

nthreads doesn't mean "number of threads" in this code; instead, it means "total number of output elements to process", which can be very large.

Feel free to let me know if you have any questions.

const T* grad_output,
const T& spatial_scale,
int channels,
Expand All @@ -187,11 +189,11 @@ void roi_align_backward_kernel_impl(
bool aligned,
T* grad_input,
const T* rois,
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
for (int index = 0; index < nthreads; index++) {
int64_t n_stride,
int64_t c_stride,
int64_t h_stride,
int64_t w_stride) {
for (int64_t index = 0; index < nthreads; index++) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
Expand Down Expand Up @@ -219,10 +221,10 @@ void roi_align_backward_kernel_impl(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
T* offset_grad_input = grad_input +
((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width);

int output_offset = n * n_stride + c * c_stride;
int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
Expand Down Expand Up @@ -359,10 +361,10 @@ at::Tensor roi_align_backward_kernel(
}

// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);

auto rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
Expand Down
36 changes: 19 additions & 17 deletions torchvision/csrc/ops/cuda/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ __device__ T bilinear_interpolate(

template <typename T>
__global__ void roi_align_forward_kernel_impl(
int nthreads,
int64_t nthreads,
const T* input,
const T spatial_scale,
int channels,
Expand All @@ -79,7 +79,7 @@ __global__ void roi_align_forward_kernel_impl(
bool aligned,
const T* rois,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
Expand Down Expand Up @@ -107,8 +107,8 @@ __global__ void roi_align_forward_kernel_impl(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
const T* offset_input = input +
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
Expand Down Expand Up @@ -203,7 +203,7 @@ __device__ void bilinear_interpolate_gradient(

template <typename T>
__global__ void roi_align_backward_kernel_impl(
int nthreads,
int64_t nthreads,
const T* grad_output,
const T spatial_scale,
int channels,
Expand All @@ -215,12 +215,12 @@ __global__ void roi_align_backward_kernel_impl(
bool aligned,
T* grad_input,
const T* rois,
int n_stride,
int c_stride,
int h_stride,
int w_stride,
const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int64_t n_stride,
int64_t c_stride,
int64_t h_stride,
int64_t w_stride,
const int64_t memory_span) {
CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
Expand Down Expand Up @@ -250,7 +250,8 @@ __global__ void roi_align_backward_kernel_impl(

// We need to index the gradient using the tensor strides to access the
// correct values.
const int output_offset = n * n_stride + c * c_stride;
const int64_t output_offset =
static_cast<int64_t>(n) * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
Expand All @@ -265,7 +266,8 @@ __global__ void roi_align_backward_kernel_impl(
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

const int input_offset = (roi_batch_ind * channels + c) * height * width;
const int64_t input_offset =
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;

for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
Expand Down Expand Up @@ -432,10 +434,10 @@ at::Tensor roi_align_backward_kernel(
return grad_input;
}

int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);

at::globalContext().alertNotDeterministic("roi_align_backward_kernel");

Expand Down
Loading