[AIMIGRAPHX-885][AIMIGRAPGX-987] Use External Stream Contexts#4775
[AIMIGRAPHX-885][AIMIGRAPGX-987] Use External Stream Contexts#4775TedThemistokleous wants to merge 10 commits intodevelopfrom
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #4775 +/- ##
========================================
Coverage 92.49% 92.49%
========================================
Files 583 583
Lines 29562 29562
========================================
Hits 27343 27343
Misses 2219 2219 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
Enables MIGraphX GPU async execution to run directly on a caller-provided HIP stream (external stream contexts) to reduce internal synchronization/stalls, and adds GPU tests to validate external-stream behavior and fallback behavior.
Changes:
- Add external-stream support in
gpu::context/hip_device::stream(override stream used by the context during async eval). - Adjust async synchronization logic (
wait_for/finish_on) to avoid creating/using an extra internal stream when an external stream is provided. - Add a comprehensive new GPU test suite covering external stream override, async eval behavior, and fallback paths.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
src/targets/gpu/include/migraphx/gpu/context.hpp |
Adds external stream override plumbing and modifies async sync behavior to use caller stream. |
test/gpu/external_stream.cpp |
Adds new GPU tests for external stream override, async eval correctness, and state cleanup expectations. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if(not get_stream().has_external_stream()) | ||
| { | ||
| get_stream().record(finish_event.get()); | ||
| auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0); | ||
| if(status != hipSuccess) | ||
| MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); | ||
| } |
There was a problem hiding this comment.
finish_on() skips all work when an external stream is active, but it also never restores the stream state. Since program::eval() calls wait_for()/finish_on() around async execution, this means an async eval will leave the GPU context permanently bound to the external stream (affecting later sync evals and finish()). finish_on() (or a dedicated scope guard) should clear the external stream and restore library handles back to the internal/default stream after the async run completes.
| if(not get_stream().has_external_stream()) | |
| { | |
| get_stream().record(finish_event.get()); | |
| auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0); | |
| if(status != hipSuccess) | |
| MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); | |
| } | |
| if(get_stream().has_external_stream()) | |
| { | |
| get_stream().set_external_stream(nullptr); | |
| return; | |
| } | |
| get_stream().record(finish_event.get()); | |
| auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0); | |
| if(status != hipSuccess) | |
| MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); |
| auto *ext = queue.get<hipStream_t>(); | ||
| if(ext == nullptr) | ||
| { | ||
| auto status = hipEventRecord(begin_event.get(), ext); | ||
| if(status != hipSuccess) | ||
| MIGRAPHX_THROW("Failed to record: " + hip_error(status)); | ||
| get_stream().wait(begin_event.get()); | ||
| } | ||
| else | ||
| { |
There was a problem hiding this comment.
wait_for() calls queue.get<hipStream_t>() unconditionally. If the caller passes a null stream via execution_environment{nullptr, true} (which is a reasonable way to request the fallback/event path), the any_ptr was constructed with a typed nullptr and any_ptr::get() will hit assert(not ti or ptr != nullptr) in debug builds. Consider checking queue.unsafe_get() == nullptr first and treating that as the null-stream fallback, only calling get<hipStream_t>() when the pointer is non-null.
| auto *ext = queue.get<hipStream_t>(); | |
| if(ext == nullptr) | |
| { | |
| auto status = hipEventRecord(begin_event.get(), ext); | |
| if(status != hipSuccess) | |
| MIGRAPHX_THROW("Failed to record: " + hip_error(status)); | |
| get_stream().wait(begin_event.get()); | |
| } | |
| else | |
| { | |
| if(queue.unsafe_get() == nullptr) | |
| { | |
| auto status = hipEventRecord(begin_event.get(), nullptr); | |
| if(status != hipSuccess) | |
| MIGRAPHX_THROW("Failed to record: " + hip_error(status)); | |
| get_stream().wait(begin_event.get()); | |
| } | |
| else | |
| { | |
| auto* ext = queue.get<hipStream_t>(); |
bdevorem
left a comment
There was a problem hiding this comment.
couple questions, thanks Ted
| auto host_result = migraphx::gpu::from_gpu(gout); | ||
| verify_data(host_result, out_shape, 12.0f); | ||
| } | ||
|
|
There was a problem hiding this comment.
the PR description says external libs reset to default stream on clear/finish, but the tests seem to mostly assert get_queue() or has_external_stream() and numerical results. I think none would fail if MIOpen or rocBLAS were left bound to the customer stream?
d2cfb6b to
563c78b
Compare
bdevorem
left a comment
There was a problem hiding this comment.
I think a rebase/merge will solve the CI problems. lgtm otherwise
pfultz2
left a comment
There was a problem hiding this comment.
I dont think we should do this in the wait_for and finish_on functions as it could change the semantics of the function. Instead we should add a use_queue method to the context interface and use that directly.
Sure rebased this off develop |
7b92f18 to
a164acb
Compare
Okay let me add this. This is similar and just do the create/set in the use_queue or use_external() thread?
So is the idea then run_async() -> bind the stream if its not the null/default stream? otherwise we just create an internal stream for regular run()? |
|
|
Added changes based on Paul's comments so that we odn't modify wait_for, finish_on and just use a use_queue and set_queue_context. |
Motivation
Customer workload seeing some stalls during inference. This allows us to use the customer hipSteam passed to context via run_async so that we don't need to internally sync and manage a thread within MIGraphX. This allows the synchronization to be handled externally.
As an added benefit if not external thread is used we should fall back to the old fork_join run on the GPU where we internally create a stream to sync events onto.
Technical Details
Adds additional conditions to the wait_for , finish_on calls in context.cpp such that we avoid new stream creation for async runs while also simplifying much of the code.
Test cases have been added for this to ensure we don't break existing functionality.
Additional code added to ensure we set external libraries like BLAS and MIOPEN to use the default stream on clear
Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot Applicable