-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Fix bus error or segfault from roi_align with large batchsize #9441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
8c71ea8
40b2276
d9ab5ce
544d960
76f6a16
b6b7ab7
b0004a1
8ff88ed
bad6aba
e0b1a5b
1d9d1d4
f40a46d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,7 +26,8 @@ void roi_align_forward_kernel_impl( | |||||||||||||||||||||||
| // can be parallelized using omp | ||||||||||||||||||||||||
| // #pragma omp parallel for num_threads(32) | ||||||||||||||||||||||||
| for (int n = 0; n < n_rois; n++) { | ||||||||||||||||||||||||
| int index_n = n * channels * pooled_width * pooled_height; | ||||||||||||||||||||||||
| int64_t index_n = | ||||||||||||||||||||||||
| static_cast<int64_t>(n) * channels * pooled_width * pooled_height; | ||||||||||||||||||||||||
|
Comment on lines
+29
to
+30
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Type and cast are good. |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const T* offset_rois = rois + n * 5; | ||||||||||||||||||||||||
| int roi_batch_ind = offset_rois[0]; | ||||||||||||||||||||||||
|
|
@@ -78,14 +79,15 @@ void roi_align_forward_kernel_impl( | |||||||||||||||||||||||
| pre_calc); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int c = 0; c < channels; c++) { | ||||||||||||||||||||||||
| int index_n_c = index_n + c * pooled_width * pooled_height; | ||||||||||||||||||||||||
| const T* offset_input = | ||||||||||||||||||||||||
| input + (roi_batch_ind * channels + c) * height * width; | ||||||||||||||||||||||||
| int64_t index_n_c = | ||||||||||||||||||||||||
| index_n + static_cast<int64_t>(c) * pooled_width * pooled_height; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| const T* offset_input = input + | ||||||||||||||||||||||||
| (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width; | ||||||||||||||||||||||||
|
Comment on lines
+82
to
+84
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type and cast are good |
||||||||||||||||||||||||
| int pre_calc_index = 0; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int ph = 0; ph < pooled_height; ph++) { | ||||||||||||||||||||||||
| for (int pw = 0; pw < pooled_width; pw++) { | ||||||||||||||||||||||||
| int index = index_n_c + ph * pooled_width + pw; | ||||||||||||||||||||||||
| int64_t index = index_n_c + ph * pooled_width + pw; | ||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type is good |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| T output_val = 0.; | ||||||||||||||||||||||||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | ||||||||||||||||||||||||
|
|
@@ -175,7 +177,7 @@ inline void add(T* address, const T& val) { | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||
| void roi_align_backward_kernel_impl( | ||||||||||||||||||||||||
| int nthreads, | ||||||||||||||||||||||||
| int64_t nthreads, | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain why nthreads needs to be int64_t? If it's for integer comparison to not warn, we could just cast?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the question. nthreads in the backward kernel is the total number of output elements (n_rois * channels * pooled_height * pooled_width), which can exceed the int range for large batch sizes, so I changed it to int64_t. The CPU backward kernel does use nthreads directly as the loop bound. A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value could be actually large and gets truncated at the call site before the function body runs. In the author's reproducing example (batch_size=172, default 1000 proposals per image), the element count overflows a 32-bit int. I added the backward kernel test in the latest commit. If I change nthreads back to int, that test fails.
Feel free to let me know if you have any questions.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. For my own ref:
vision/torchvision/csrc/ops/cuda/roi_align_kernel.cu Lines 378 to 379 in d7400a3
where output size was already inferred as uint64 since size() returns uint64: vision/torchvision/csrc/ops/cuda/roi_align_kernel.cu Lines 354 to 362 in d7400a3
|
||||||||||||||||||||||||
| const T* grad_output, | ||||||||||||||||||||||||
| const T& spatial_scale, | ||||||||||||||||||||||||
| int channels, | ||||||||||||||||||||||||
|
|
@@ -187,11 +189,11 @@ void roi_align_backward_kernel_impl( | |||||||||||||||||||||||
| bool aligned, | ||||||||||||||||||||||||
| T* grad_input, | ||||||||||||||||||||||||
| const T* rois, | ||||||||||||||||||||||||
| int n_stride, | ||||||||||||||||||||||||
| int c_stride, | ||||||||||||||||||||||||
| int h_stride, | ||||||||||||||||||||||||
| int w_stride) { | ||||||||||||||||||||||||
| for (int index = 0; index < nthreads; index++) { | ||||||||||||||||||||||||
| int64_t n_stride, | ||||||||||||||||||||||||
| int64_t c_stride, | ||||||||||||||||||||||||
| int64_t h_stride, | ||||||||||||||||||||||||
| int64_t w_stride) { | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| for (int64_t index = 0; index < nthreads; index++) { | ||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. int64_t is needed. |
||||||||||||||||||||||||
| // (n, c, ph, pw) is an element in the pooled output | ||||||||||||||||||||||||
| int pw = index % pooled_width; | ||||||||||||||||||||||||
| int ph = (index / pooled_width) % pooled_height; | ||||||||||||||||||||||||
|
|
@@ -219,10 +221,10 @@ void roi_align_backward_kernel_impl( | |||||||||||||||||||||||
| T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); | ||||||||||||||||||||||||
| T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| T* offset_grad_input = | ||||||||||||||||||||||||
| grad_input + ((roi_batch_ind * channels + c) * height * width); | ||||||||||||||||||||||||
| T* offset_grad_input = grad_input + | ||||||||||||||||||||||||
| ((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width); | ||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. casting is needed. |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int output_offset = n * n_stride + c * c_stride; | ||||||||||||||||||||||||
| int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride; | ||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type and cast are good. |
||||||||||||||||||||||||
| const T* offset_grad_output = grad_output + output_offset; | ||||||||||||||||||||||||
| const T grad_output_this_bin = | ||||||||||||||||||||||||
| offset_grad_output[ph * h_stride + pw * w_stride]; | ||||||||||||||||||||||||
|
|
@@ -359,10 +361,10 @@ at::Tensor roi_align_backward_kernel( | |||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // get stride values to ensure indexing into gradients is correct. | ||||||||||||||||||||||||
| int n_stride = grad.stride(0); | ||||||||||||||||||||||||
| int c_stride = grad.stride(1); | ||||||||||||||||||||||||
| int h_stride = grad.stride(2); | ||||||||||||||||||||||||
| int w_stride = grad.stride(3); | ||||||||||||||||||||||||
| int64_t n_stride = grad.stride(0); | ||||||||||||||||||||||||
| int64_t c_stride = grad.stride(1); | ||||||||||||||||||||||||
| int64_t h_stride = grad.stride(2); | ||||||||||||||||||||||||
| int64_t w_stride = grad.stride(3); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| auto rois_ = rois.contiguous(); | ||||||||||||||||||||||||
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,7 +67,7 @@ __device__ T bilinear_interpolate( | |
|
|
||
| template <typename T> | ||
| __global__ void roi_align_forward_kernel_impl( | ||
| int nthreads, | ||
| int64_t nthreads, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type is good |
||
| const T* input, | ||
| const T spatial_scale, | ||
| int channels, | ||
|
|
@@ -79,7 +79,7 @@ __global__ void roi_align_forward_kernel_impl( | |
| bool aligned, | ||
| const T* rois, | ||
| T* output) { | ||
| CUDA_1D_KERNEL_LOOP(index, nthreads) { | ||
| CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type is good |
||
| // (n, c, ph, pw) is an element in the pooled output | ||
| int pw = index % pooled_width; | ||
| int ph = (index / pooled_width) % pooled_height; | ||
|
|
@@ -107,8 +107,8 @@ __global__ void roi_align_forward_kernel_impl( | |
| T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); | ||
| T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); | ||
|
|
||
| const T* offset_input = | ||
| input + (roi_batch_ind * channels + c) * height * width; | ||
| const T* offset_input = input + | ||
| (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cast is needed here. |
||
|
|
||
| // We use roi_bin_grid to sample the grid and mimic integral | ||
| int roi_bin_grid_h = (sampling_ratio > 0) | ||
|
|
@@ -203,7 +203,7 @@ __device__ void bilinear_interpolate_gradient( | |
|
|
||
| template <typename T> | ||
| __global__ void roi_align_backward_kernel_impl( | ||
| int nthreads, | ||
| int64_t nthreads, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type is good. |
||
| const T* grad_output, | ||
| const T spatial_scale, | ||
| int channels, | ||
|
|
@@ -215,12 +215,12 @@ __global__ void roi_align_backward_kernel_impl( | |
| bool aligned, | ||
| T* grad_input, | ||
| const T* rois, | ||
| int n_stride, | ||
| int c_stride, | ||
| int h_stride, | ||
| int w_stride, | ||
| const int memory_span) { | ||
| CUDA_1D_KERNEL_LOOP(index, nthreads) { | ||
| int64_t n_stride, | ||
| int64_t c_stride, | ||
| int64_t h_stride, | ||
| int64_t w_stride, | ||
|
||
| const int64_t memory_span) { | ||
| CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type is good. |
||
| // (n, c, ph, pw) is an element in the pooled output | ||
| int pw = index % pooled_width; | ||
| int ph = (index / pooled_width) % pooled_height; | ||
|
|
@@ -250,7 +250,8 @@ __global__ void roi_align_backward_kernel_impl( | |
|
|
||
| // We need to index the gradient using the tensor strides to access the | ||
| // correct values. | ||
| const int output_offset = n * n_stride + c * c_stride; | ||
| const int64_t output_offset = | ||
| static_cast<int64_t>(n) * n_stride + c * c_stride; | ||
|
Comment on lines
+253
to
+254
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Type is good. The cast is technically redundant since n_stride is already int64_t and the multiplication would promote automatically. I kept it to make the 64-bit intent explicit. |
||
| const T* offset_grad_output = grad_output + output_offset; | ||
| const T grad_output_this_bin = | ||
| offset_grad_output[ph * h_stride + pw * w_stride]; | ||
|
|
@@ -265,7 +266,8 @@ __global__ void roi_align_backward_kernel_impl( | |
| // We do average (integral) pooling inside a bin | ||
| const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 | ||
|
|
||
| const int input_offset = (roi_batch_ind * channels + c) * height * width; | ||
| const int64_t input_offset = | ||
| (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width; | ||
|
Comment on lines
+269
to
+270
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type and cast are good. |
||
|
|
||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 | ||
| { | ||
|
|
@@ -432,10 +434,10 @@ at::Tensor roi_align_backward_kernel( | |
| return grad_input; | ||
| } | ||
|
|
||
| int n_stride = grad.stride(0); | ||
| int c_stride = grad.stride(1); | ||
| int h_stride = grad.stride(2); | ||
| int w_stride = grad.stride(3); | ||
| int64_t n_stride = grad.stride(0); | ||
| int64_t c_stride = grad.stride(1); | ||
| int64_t h_stride = grad.stride(2); | ||
| int64_t w_stride = grad.stride(3); | ||
|
|
||
| at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.