-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Fix bus error or segfault from roi_align with large batchsize #9441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8c71ea8
40b2276
d9ab5ce
544d960
76f6a16
b6b7ab7
b0004a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -643,6 +643,40 @@ def test_performance_mps(self): | |||||
| execution_time_ms < execution_time_ms_threshold | ||||||
| ), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms" | ||||||
|
|
||||||
| @pytest.mark.parametrize("device", cpu_and_cuda()) | ||||||
| def test_roi_align_large_index(self, device): | ||||||
| """Regression test for https://github.com/pytorch/vision/issues/8206""" | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| pooled_h, pooled_w = 7, 7 | ||||||
| channels = 4 | ||||||
| # 11M * 4 * 7 * 7 = 2,156,000,000 > INT_MAX | ||||||
| n_rois = 11_000_000 | ||||||
| num_imgs = 2 | ||||||
| height, width = 4, 4 | ||||||
| spatial_scale = 1.0 | ||||||
| sampling_ratio = 2 | ||||||
|
|
||||||
| x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device, requires_grad=True) | ||||||
| rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device) | ||||||
|
|
||||||
| rois[:, 0] = torch.randint(0, num_imgs, (n_rois,)) | ||||||
| rois[:, 1] = 0 | ||||||
| rois[:, 2] = 0 | ||||||
| rois[:, 3] = width - 1 | ||||||
| rois[:, 4] = height - 1 | ||||||
|
|
||||||
| # Call the C++ kernel directly, in case that torchvision.ops.roi_align may fall | ||||||
| # back to a pure-Python path that doesn't have the int32 overflow bug. | ||||||
| result = torch.ops.torchvision.roi_align(x, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, False) | ||||||
|
|
||||||
| # Forward kernel test | ||||||
| assert result.shape == (n_rois, channels, pooled_h, pooled_w) | ||||||
| assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug" | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here and below, don't specify anything beyond the assertion itself.
Suggested change
|
||||||
|
|
||||||
| # Backward kernel test | ||||||
| result.sum().backward() | ||||||
| assert x.grad is not None, "x.grad is None — backward was not executed" | ||||||
| assert x.grad.abs().sum() > 0, "x.grad is all zeros — likely an index overflow bug in the backward kernel" | ||||||
|
|
||||||
|
|
||||||
| class TestPSRoIAlign(RoIOpTester): | ||||||
| mps_backward_atol = 5e-2 | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,10 +8,10 @@ namespace detail { | |
|
|
||
| template <typename T> | ||
| struct PreCalc { | ||
| int pos1; | ||
| int pos2; | ||
| int pos3; | ||
| int pos4; | ||
| int64_t pos1; | ||
| int64_t pos2; | ||
| int64_t pos3; | ||
| int64_t pos4; | ||
| T w1; | ||
| T w2; | ||
| T w3; | ||
|
|
@@ -42,7 +42,7 @@ void pre_calc_for_bilinear_interpolate( | |
| int roi_bin_grid_h, | ||
| int roi_bin_grid_w, | ||
| std::vector<PreCalc<T>>& pre_calc) { | ||
| int pre_calc_index = 0; | ||
| int64_t pre_calc_index = 0; | ||
| for (int ph = 0; ph < pooled_height; ph++) { | ||
| for (int pw = 0; pw < pooled_width; pw++) { | ||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | ||
|
|
@@ -106,10 +106,10 @@ void pre_calc_for_bilinear_interpolate( | |
|
|
||
| // save weights and indices | ||
| PreCalc<T> pc; | ||
| pc.pos1 = y_low * width + x_low; | ||
| pc.pos2 = y_low * width + x_high; | ||
| pc.pos3 = y_high * width + x_low; | ||
| pc.pos4 = y_high * width + x_high; | ||
| pc.pos1 = static_cast<int64_t>(y_low) * width + x_low; | ||
| pc.pos2 = static_cast<int64_t>(y_low) * width + x_high; | ||
| pc.pos3 = static_cast<int64_t>(y_high) * width + x_low; | ||
| pc.pos4 = static_cast<int64_t>(y_high) * width + x_high; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you double check that these casts are needed? claude says: y_low and y_high are pixel coordinates bounded by height, and width is the image width. y_low * width + x_low is at most height * width, which is the number of pixels in a single channel of a single image. That's not going to overflow int. Please check every single other change in this PR.
||
| pc.w1 = w1; | ||
| pc.w2 = w2; | ||
| pc.w3 = w3; | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,7 +26,8 @@ void roi_align_forward_kernel_impl( | |||||||||||||||||||||||
| // can be parallelized using omp | ||||||||||||||||||||||||
| // #pragma omp parallel for num_threads(32) | ||||||||||||||||||||||||
| for (int n = 0; n < n_rois; n++) { | ||||||||||||||||||||||||
| int index_n = n * channels * pooled_width * pooled_height; | ||||||||||||||||||||||||
| int64_t index_n = | ||||||||||||||||||||||||
| static_cast<int64_t>(n) * channels * pooled_width * pooled_height; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const T* offset_rois = rois + n * 5; | ||||||||||||||||||||||||
| int roi_batch_ind = offset_rois[0]; | ||||||||||||||||||||||||
|
|
@@ -78,14 +79,15 @@ void roi_align_forward_kernel_impl( | |||||||||||||||||||||||
| pre_calc); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int c = 0; c < channels; c++) { | ||||||||||||||||||||||||
| int index_n_c = index_n + c * pooled_width * pooled_height; | ||||||||||||||||||||||||
| const T* offset_input = | ||||||||||||||||||||||||
| input + (roi_batch_ind * channels + c) * height * width; | ||||||||||||||||||||||||
| int64_t index_n_c = | ||||||||||||||||||||||||
| index_n + static_cast<int64_t>(c) * pooled_width * pooled_height; | ||||||||||||||||||||||||
| const T* offset_input = input + | ||||||||||||||||||||||||
| (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width; | ||||||||||||||||||||||||
| int pre_calc_index = 0; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int ph = 0; ph < pooled_height; ph++) { | ||||||||||||||||||||||||
| for (int pw = 0; pw < pooled_width; pw++) { | ||||||||||||||||||||||||
| int index = index_n_c + ph * pooled_width + pw; | ||||||||||||||||||||||||
| int64_t index = index_n_c + ph * pooled_width + pw; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| T output_val = 0.; | ||||||||||||||||||||||||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | ||||||||||||||||||||||||
|
|
@@ -175,7 +177,7 @@ inline void add(T* address, const T& val) { | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||
| void roi_align_backward_kernel_impl( | ||||||||||||||||||||||||
| int nthreads, | ||||||||||||||||||||||||
| int64_t nthreads, | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain why `nthreads` needs to be `int64_t`? If it's for integer comparison to not warn, we could just cast?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the question. `nthreads` in the backward kernel is the total number of output elements (`n_rois * channels * pooled_height * pooled_width`), which can exceed `INT_MAX` for large batch sizes. I changed it to `int64_t` so the value is not truncated at the call site. The CPU backward kernel does use `nthreads` as the loop bound over all output elements. A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value could be actually large and gets truncated at the call site before the function body runs. In the author's reproducing example (batch_size=172, default 1000 proposals per image), the element count already exceeds `INT_MAX`. I added the backward kernel test in the latest commit. If I change `nthreads` back to `int`, the test fails.
Feel free to let me know if you have any questions.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks. For my own ref:
vision/torchvision/csrc/ops/cuda/roi_align_kernel.cu Lines 378 to 379 in d7400a3
where output size was already inferred as uint64 since size() returns uint64: vision/torchvision/csrc/ops/cuda/roi_align_kernel.cu Lines 354 to 362 in d7400a3
|
||||||||||||||||||||||||
| const T* grad_output, | ||||||||||||||||||||||||
| const T& spatial_scale, | ||||||||||||||||||||||||
| int channels, | ||||||||||||||||||||||||
|
|
@@ -187,11 +189,11 @@ void roi_align_backward_kernel_impl( | |||||||||||||||||||||||
| bool aligned, | ||||||||||||||||||||||||
| T* grad_input, | ||||||||||||||||||||||||
| const T* rois, | ||||||||||||||||||||||||
| int n_stride, | ||||||||||||||||||||||||
| int c_stride, | ||||||||||||||||||||||||
| int h_stride, | ||||||||||||||||||||||||
| int w_stride) { | ||||||||||||||||||||||||
| for (int index = 0; index < nthreads; index++) { | ||||||||||||||||||||||||
| int64_t n_stride, | ||||||||||||||||||||||||
| int64_t c_stride, | ||||||||||||||||||||||||
| int64_t h_stride, | ||||||||||||||||||||||||
| int64_t w_stride) { | ||||||||||||||||||||||||
| for (int64_t index = 0; index < nthreads; index++) { | ||||||||||||||||||||||||
| // (n, c, ph, pw) is an element in the pooled output | ||||||||||||||||||||||||
| int pw = index % pooled_width; | ||||||||||||||||||||||||
| int ph = (index / pooled_width) % pooled_height; | ||||||||||||||||||||||||
|
|
@@ -219,10 +221,10 @@ void roi_align_backward_kernel_impl( | |||||||||||||||||||||||
| T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); | ||||||||||||||||||||||||
| T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| T* offset_grad_input = | ||||||||||||||||||||||||
| grad_input + ((roi_batch_ind * channels + c) * height * width); | ||||||||||||||||||||||||
| T* offset_grad_input = grad_input + | ||||||||||||||||||||||||
| ((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int output_offset = n * n_stride + c * c_stride; | ||||||||||||||||||||||||
| int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride; | ||||||||||||||||||||||||
| const T* offset_grad_output = grad_output + output_offset; | ||||||||||||||||||||||||
| const T grad_output_this_bin = | ||||||||||||||||||||||||
| offset_grad_output[ph * h_stride + pw * w_stride]; | ||||||||||||||||||||||||
|
|
@@ -359,10 +361,10 @@ at::Tensor roi_align_backward_kernel( | |||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // get stride values to ensure indexing into gradients is correct. | ||||||||||||||||||||||||
| int n_stride = grad.stride(0); | ||||||||||||||||||||||||
| int c_stride = grad.stride(1); | ||||||||||||||||||||||||
| int h_stride = grad.stride(2); | ||||||||||||||||||||||||
| int w_stride = grad.stride(3); | ||||||||||||||||||||||||
| int64_t n_stride = grad.stride(0); | ||||||||||||||||||||||||
| int64_t c_stride = grad.stride(1); | ||||||||||||||||||||||||
| int64_t h_stride = grad.stride(2); | ||||||||||||||||||||||||
| int64_t w_stride = grad.stride(3); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| auto rois_ = rois.contiguous(); | ||||||||||||||||||||||||
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually taking a super long time on CPU:
https://productionresultssa18.blob.core.windows.net/actions-results/5a8fef95-8252-44a2-86fb-264662c987f6/workflow-job-run-05f658de-4299-5615-a958-010a82bf0999/logs/job/job-logs.txt?rsct=text%2Fplain&se=2026-03-23T10%3A33%3A43Z&sig=9bWAdM4OgZyShqk2hZAKOacge4p93mIQqKcfYlkUHVI%3D&ske=2026-03-23T12%3A13%3A31Z&skoid=ca7593d4-ee42-46cd-af88-8b886a2f84eb&sks=b&skt=2026-03-23T08%3A13%3A31Z&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skv=2025-11-05&sp=r&spr=https&sr=b&st=2026-03-23T10%3A23%3A38Z&sv=2025-11-05
on GPU it seems OK. Let's just skip the test on CPU indicating it takes too long.