
Fix bus error or segfault from roi_align with large batchsize #9441

Open
zy1git wants to merge 7 commits into pytorch:main from zy1git:issue-8206

Conversation

@zy1git
Contributor

@zy1git zy1git commented Mar 13, 2026

Summary
Bug: roi_align in torchvision crashes with a bus error/segfault on CPU or returns silently wrong (all-zero) results on CUDA when the total number of output elements exceeds INT_MAX (~2.1 billion). This is caused by 32-bit int overflow in index arithmetic within the C++ and CUDA kernels.

Root Cause: The kernels use int for composite index calculations like n × channels × pooled_width × pooled_height and pointer offsets like (roi_batch_ind × channels + c) × height × width. When these products exceed 2,147,483,647, the int wraps to a negative value, causing out-of-bounds memory access.

Example: FasterRCNN with batch_size=172 generates ~172,000 ROIs. The output index reaches 171,999 × 256 × 7 × 7 = 2,157,555,456 > INT_MAX, which matches the reporter's observed threshold exactly.

Fix: Promoted int to int64_t for all index, offset, and stride variables in the roi_align C++ and CUDA kernels.
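The wraparound can be reproduced in a few lines of plain Python (a sketch that mimics C 32-bit signed wraparound with bit masking; the shape numbers come from the reproduction above, not from the kernel source):

```python
def as_int32(x):
    """Reduce a Python int to the value a C 32-bit signed int would hold."""
    x &= 0xFFFFFFFF  # keep the low 32 bits
    return x - (1 << 32) if x >= (1 << 31) else x

# Shapes from the FasterRCNN reproduction described above
n_rois, channels, pooled_h, pooled_w = 172_000, 256, 7, 7

total = n_rois * channels * pooled_h * pooled_w
print(total)            # 2157568000, which exceeds INT_MAX (2147483647)
print(as_int32(total))  # -2137399296: the wrapped value an int variable would hold
```

A negative offset like this is what becomes an out-of-bounds access on CPU and a skipped loop (all-zero output) on CUDA.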

Test Plan
New overflow regression test
pytest test/test_ops.py::TestRoIAlign::test_roi_align_large_index -v

Existing tests — verify no regressions
pytest test/test_ops.py::TestRoIAlign -v

Fixes #8206

@pytorch-bot

pytorch-bot bot commented Mar 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9441

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b0004a1 with merge base 6285457:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the cla signed label Mar 13, 2026
@zy1git zy1git marked this pull request as draft March 13, 2026 10:03
test/test_ops.py Outdated
output_bytes = n_rois * channels * pooled_h * pooled_w * 4  # float32
if output_bytes > 9 * 1024**3:
    pytest.skip("Test requires ~9 GB of memory")

Member

All these values are statically defined, so this if block is either always True or always False.

Contributor Author

Thanks for pointing this out. I agree. I removed that part in the new commit.

test/test_ops.py Outdated
try:
    x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device)
    rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device)
except RuntimeError:
    pytest.skip("Not enough memory to allocate test tensors")
Member
@NicolasHug NicolasHug Mar 16, 2026
Please verify tests aren't being skipped on the CI. If they pass, remove the try/except; if they don't, we'll have to consider other strategies to test this.

Contributor Author

Thanks for the comment. I verified that the tests are not being skipped on the CI and my devserver. I removed the try/except in the new commit.

template <typename T>
void roi_align_backward_kernel_impl(
-    int nthreads,
+    int64_t nthreads,
Member
@NicolasHug NicolasHug Mar 16, 2026

Can you explain why nthreads needs to be int64_t? It should never need to be that large? If it's for integer comparison to not warn, we could just cast?

Contributor Author
@zy1git zy1git Mar 17, 2026

Thanks for the question. nthreads in the backward kernel is grad.numel(), which equals n_rois × channels × pooled_h × pooled_w.

I changed to int nthreads in both .cpp and .cu files and ran the test (only the forward kernel test, no backward kernel test due to the large memory requirement). The CPU test passed but the CUDA test failed with all-zero output. The reason is that the CPU forward kernel doesn't use nthreads — it loops over n_rois separately, and the overflow is handled by the int64_t changes to index_n, index_n_c, and index. The CUDA forward kernel uses a flat loop with nthreads as the bound, so truncating to int caused nthreads to wrap to a negative value, making the loop condition immediately false and skipping all output computation — resulting in all-zero output.

The CPU backward kernel does use nthreads in the same flat-loop pattern as CUDA (for (int64_t index = 0; index < nthreads; ...)) and receives the same overflowing value via grad.numel(), so it needs int64_t for the same reason.

A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?

Contributor Author
@zy1git zy1git Mar 17, 2026

As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value can actually be that large and gets truncated at the call site before the function body runs. In the reporter's reproduction (batch_size=172, default 1000 proposals per image), nthreads is 172,000 × 256 × 7 × 7 = 2,157,568,000 > INT_MAX.

I added the backward kernel test in the latest commit. If I change int64_t nthreads to int nthreads, the CPU backward test fails with all-zero gradients because nthreads gets truncated to a negative value and the loop never executes.

nthreads doesn't mean "number of threads" in this code; rather, it means "total number of output elements to process", which can be very large.
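The failure mode described above can be illustrated in plain Python (a sketch of the flat-loop control flow only, not the actual kernel; the element count is the one from the reproduction):

```python
def flat_loop_sum(values, nthreads):
    """Mimic the kernels' flat loop: for (index = 0; index < nthreads; ++index)."""
    out = 0.0
    for index in range(nthreads):  # a negative bound yields an empty range
        out += values[index % len(values)]
    return out

values = [1.0, 2.0, 3.0]

# Correct 64-bit bound: the loop body runs
print(flat_loop_sum(values, 6))          # 12.0

# Bound truncated to 32 bits: negative, zero iterations, "all-zero" results
truncated = 2_157_568_000 - (1 << 32)    # -2137399296, what an int parameter would hold
print(flat_loop_sum(values, truncated))  # 0.0
```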

Feel free to let me know if you have any questions.

@zy1git zy1git marked this pull request as ready for review March 20, 2026 08:27
