-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Fix bus error or segfault from roi_align with large batchsize #9441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
8c71ea8
40b2276
d9ab5ce
544d960
76f6a16
b6b7ab7
b0004a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -643,6 +643,42 @@ def test_performance_mps(self): | |
| execution_time_ms < execution_time_ms_threshold | ||
| ), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms" | ||
|
|
||
| @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) | ||
| def test_roi_align_large_index(self, device): | ||
| """Regression test for https://github.com/pytorch/vision/issues/8206""" | ||
| pooled_h, pooled_w = 7, 7 | ||
| channels = 4 | ||
| # 11M * 4 * 7 * 7 = 2,156,000,000 > INT_MAX | ||
| n_rois = 11_000_000 | ||
| num_imgs = 2 | ||
| height, width = 4, 4 | ||
| spatial_scale = 1.0 | ||
| sampling_ratio = 2 | ||
|
|
||
| output_bytes = n_rois * channels * pooled_h * pooled_w * 4 # float32 | ||
| if output_bytes > 9 * 1024**3: | ||
| pytest.skip("Test requires ~9 GB of memory") | ||
|
|
||
| try: | ||
| x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device) | ||
| rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device) | ||
| except RuntimeError: | ||
| pytest.skip("Not enough memory to allocate test tensors") | ||
|
||
|
|
||
| rois[:, 0] = torch.randint(0, num_imgs, (n_rois,)) | ||
| rois[:, 1] = 0 | ||
| rois[:, 2] = 0 | ||
| rois[:, 3] = width - 1 | ||
| rois[:, 4] = height - 1 | ||
|
|
||
| try: | ||
| result = torch.ops.torchvision.roi_align(x, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, False) | ||
| except RuntimeError: | ||
| pytest.skip("Not enough memory for roi_align output") | ||
|
|
||
| assert result.shape == (n_rois, channels, pooled_h, pooled_w) | ||
| assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug" | ||
|
|
||
|
|
||
| class TestPSRoIAlign(RoIOpTester): | ||
| mps_backward_atol = 5e-2 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,8 @@ void roi_align_forward_kernel_impl( | |
| // can be parallelized using omp | ||
| // #pragma omp parallel for num_threads(32) | ||
| for (int n = 0; n < n_rois; n++) { | ||
| int index_n = n * channels * pooled_width * pooled_height; | ||
| int64_t index_n = | ||
| static_cast<int64_t>(n) * channels * pooled_width * pooled_height; | ||
|
|
||
| const T* offset_rois = rois + n * 5; | ||
| int roi_batch_ind = offset_rois[0]; | ||
|
|
@@ -78,14 +79,15 @@ void roi_align_forward_kernel_impl( | |
| pre_calc); | ||
|
|
||
| for (int c = 0; c < channels; c++) { | ||
| int index_n_c = index_n + c * pooled_width * pooled_height; | ||
| const T* offset_input = | ||
| input + (roi_batch_ind * channels + c) * height * width; | ||
| int64_t index_n_c = | ||
| index_n + static_cast<int64_t>(c) * pooled_width * pooled_height; | ||
| const T* offset_input = input + | ||
| (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width; | ||
| int pre_calc_index = 0; | ||
|
|
||
| for (int ph = 0; ph < pooled_height; ph++) { | ||
| for (int pw = 0; pw < pooled_width; pw++) { | ||
| int index = index_n_c + ph * pooled_width + pw; | ||
| int64_t index = index_n_c + ph * pooled_width + pw; | ||
|
|
||
| T output_val = 0.; | ||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | ||
|
|
@@ -175,7 +177,7 @@ inline void add(T* address, const T& val) { | |
|
|
||
| template <typename T> | ||
| void roi_align_backward_kernel_impl( | ||
| int nthreads, | ||
| int64_t nthreads, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain why nthreads needs to be int64_t? If it's for integer comparison to not warn, we could just cast?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the question. nthreads in the backward kernel is the total number of pooled-output elements (n_rois * channels * pooled_height * pooled_width), which can exceed INT_MAX for large batch sizes. I changed it to int64_t so the value is not truncated at the call site. The CPU backward kernel does use nthreads as the loop bound, so the loop index (and the stride values it is combined with) also need to be int64_t. A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value could actually be large and would get truncated at the call site before the function body runs. In the author's reproducing example (batch_size=172, default 1000 proposals per image), the element count exceeds INT_MAX. I added the backward kernel test in the latest commit. If I change nthreads back to int, that test fails, which confirms the truncation is the cause.
Feel free to let me know if you have any questions. |
||
| const T* grad_output, | ||
| const T& spatial_scale, | ||
| int channels, | ||
|
|
@@ -187,11 +189,11 @@ void roi_align_backward_kernel_impl( | |
| bool aligned, | ||
| T* grad_input, | ||
| const T* rois, | ||
| int n_stride, | ||
| int c_stride, | ||
| int h_stride, | ||
| int w_stride) { | ||
| for (int index = 0; index < nthreads; index++) { | ||
| int64_t n_stride, | ||
| int64_t c_stride, | ||
| int64_t h_stride, | ||
| int64_t w_stride) { | ||
| for (int64_t index = 0; index < nthreads; index++) { | ||
| // (n, c, ph, pw) is an element in the pooled output | ||
| int pw = index % pooled_width; | ||
| int ph = (index / pooled_width) % pooled_height; | ||
|
|
@@ -219,10 +221,10 @@ void roi_align_backward_kernel_impl( | |
| T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); | ||
| T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); | ||
|
|
||
| T* offset_grad_input = | ||
| grad_input + ((roi_batch_ind * channels + c) * height * width); | ||
| T* offset_grad_input = grad_input + | ||
| ((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width); | ||
|
|
||
| int output_offset = n * n_stride + c * c_stride; | ||
| int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride; | ||
| const T* offset_grad_output = grad_output + output_offset; | ||
| const T grad_output_this_bin = | ||
| offset_grad_output[ph * h_stride + pw * w_stride]; | ||
|
|
@@ -359,10 +361,10 @@ at::Tensor roi_align_backward_kernel( | |
| } | ||
|
|
||
| // get stride values to ensure indexing into gradients is correct. | ||
| int n_stride = grad.stride(0); | ||
| int c_stride = grad.stride(1); | ||
| int h_stride = grad.stride(2); | ||
| int w_stride = grad.stride(3); | ||
| int64_t n_stride = grad.stride(0); | ||
| int64_t c_stride = grad.stride(1); | ||
| int64_t h_stride = grad.stride(2); | ||
| int64_t w_stride = grad.stride(3); | ||
|
|
||
| auto rois_ = rois.contiguous(); | ||
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all these values are statically defined. This if block is either always True or always False.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out. I agree. I removed that part in the new commit.