From 0d1eda82c216a3bf7f3f35e5f4a9fef244a17868 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 9 Apr 2026 18:42:25 -0500 Subject: [PATCH 01/11] add external stream support to context --- .../gpu/include/migraphx/gpu/context.hpp | 76 ++++++++++++++++--- 1 file changed, 66 insertions(+), 10 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index e29414d41f3..c5126c578f0 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -90,6 +90,8 @@ struct hip_device hipStream_t get() { + if(external_stream_ != nullptr) + return external_stream_; if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{})) { setup(); @@ -144,8 +146,47 @@ struct hip_device } #endif + void set_external_stream(hipStream_t ext_stream) + { + external_stream_ = ext_stream; +#if MIGRAPHX_USE_MIOPEN + if(mihandle != nullptr) + miopenSetStream(mihandle.get(), ext_stream); +#endif +#if MIGRAPHX_USE_ROCBLAS + if(rbhandle != nullptr) + rocblas_set_stream(rbhandle.get(), ext_stream); +#endif + } + + void clear_external_stream() + { + if(external_stream_ == nullptr) + return; + external_stream_ = nullptr; + auto internal = get(); +#if MIGRAPHX_USE_MIOPEN + if(mihandle != nullptr) + miopenSetStream(mihandle.get(), internal); +#endif +#if MIGRAPHX_USE_ROCBLAS + if(rbhandle != nullptr) + rocblas_set_stream(rbhandle.get(), internal); +#endif + } + + bool has_external_stream() const { return external_stream_ != nullptr; } + void wait() const { + if(external_stream_ != nullptr) + { + setup(); + auto status = hipStreamSynchronize(external_stream_); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait: " + hip_error(status)); + return; + } if(s == nullptr) return; setup(); @@ -173,6 +214,7 @@ struct hip_device private: std::size_t id = 0; shared s = nullptr; + hipStream_t external_stream_ = nullptr; #if MIGRAPHX_USE_MIOPEN shared mihandle = nullptr; #endif @@ -336,20 +378,34 @@ struct context void wait_for(any_ptr queue) { - auto status = hipEventRecord(begin_event.get(), queue.get()); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to record: " + hip_error(status)); - - get_stream().wait(begin_event.get()); + auto ext = queue.get(); + if(ext != nullptr) + { + get_stream().set_external_stream(ext); + } + else + { + auto status = hipEventRecord(begin_event.get(), ext); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to record: " + hip_error(status)); + get_stream().wait(begin_event.get()); + } } void finish_on(any_ptr queue) { - get_stream().record(finish_event.get()); - - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); + if(get_stream().has_external_stream()) + { + get_stream().clear_external_stream(); + } + else + { + get_stream().record(finish_event.get()); + auto status = + hipStreamWaitEvent(queue.get(), finish_event.get(), 0); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); + } } any_ptr get_queue() { return get_stream().get(); } From 136cd2429ca49253b7ffa32cbccebc106cba01a8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 9 Apr 2026 18:43:12 -0500 Subject: [PATCH 02/11] [AIGenerated] Add tests for external streams as well as fallback modes --- test/gpu/external_stream.cpp | 459 +++++++++++++++++++++++++++++++++++ 1 file changed, 459 insertions(+) create mode 100644 test/gpu/external_stream.cpp diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp new file mode 100644 index 00000000000..4463694065a --- /dev/null +++ b/test/gpu/external_stream.cpp @@ -0,0 +1,459 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +using hip_stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy); + +static hip_stream_ptr create_external_stream() +{ + hipStream_t stream; + auto status = hipStreamCreate(&stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to create stream"); + return hip_stream_ptr{stream}; +} + +TEST_CASE(test_stream_override_get) +{ + migraphx::gpu::context ctx{}; + auto& stream = ctx.get_stream(); + + hipStream_t internal = stream.get(); + EXPECT(internal != nullptr); + + auto ext = create_external_stream(); + stream.set_external_stream(ext.get()); + + EXPECT(stream.get() == ext.get()); + EXPECT(stream.get() != internal); + EXPECT(stream.has_external_stream()); + + stream.clear_external_stream(); + + EXPECT(stream.get() == internal); + EXPECT(not stream.has_external_stream()); +} + +TEST_CASE(test_stream_override_get_queue) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + hipStream_t original_queue = ctx.get_queue().get(); + EXPECT(original_queue != nullptr); + + ctx.get_stream().set_external_stream(ext.get()); + EXPECT(ctx.get_queue().get() == ext.get()); + + ctx.get_stream().clear_external_stream(); + EXPECT(ctx.get_queue().get() == original_queue); +} + +TEST_CASE(test_context_wait_for_sets_external_stream) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + migraphx::any_ptr queue(ext.get()); + + hipStream_t before = ctx.get_queue().get(); + ctx.wait_for(queue); + EXPECT(ctx.get_queue().get() == ext.get()); + EXPECT(ctx.get_queue().get() != before); + + ctx.finish_on(queue); + EXPECT(ctx.get_queue().get() == before); +} + +TEST_CASE(test_external_stream_eval_uses_caller_stream) +{ + const unsigned int m = 64; + const unsigned int k = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {m, k}}); + auto y = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {k, m}})); + mm->add_instruction(migraphx::make_op("dot"), x, y); + + p.compile(migraphx::make_target("gpu")); + + migraphx::shape input_shape{migraphx::shape::float_type, {m, k}}; + migraphx::shape output_shape{migraphx::shape::float_type, {m, m}}; + auto input = migraphx::fill_argument(input_shape, 1); + auto ginput = migraphx::gpu::to_gpu(input); + + auto output = migraphx::fill_argument(output_shape, 0); + auto goutput = migraphx::gpu::to_gpu(output); + + auto ext = create_external_stream(); + + auto results = + p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_output = migraphx::gpu::from_gpu(goutput); + EXPECT(host_output != output); +} + +TEST_CASE(test_external_stream_serialized_on_caller_stream) +{ + const unsigned int n = 256; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + auto results = + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + + auto* result_data = reinterpret_cast(host_result.data()); + bool all_correct = true; + for(unsigned int i = 0; i < n; ++i) + { + if(result_data[i] != 3.0f) + { + all_correct = false; + break; + } + } + EXPECT(all_correct); +} + +TEST_CASE(test_multiple_async_evals_same_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + for(int iter = 0; iter < 5; ++iter) + { + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + } + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + + auto* result_data = reinterpret_cast(host_result.data()); + bool all_correct = true; + for(unsigned int i = 0; i < n; ++i) + { + if(result_data[i] != 3.0f) + { + all_correct = false; + break; + } + } + EXPECT(all_correct); +} + +TEST_CASE(test_external_stream_cleared_after_eval) +{ + const unsigned int n = 64; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + hipStream_t internal_stream = gpu_ctx->get_queue().get(); + + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(gpu_ctx->get_queue().get() == internal_stream); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); +} + +TEST_CASE(test_wait_for_null_stream_uses_event_fallback) +{ + migraphx::gpu::context ctx{}; + + hipStream_t null_stream = nullptr; + migraphx::any_ptr queue(null_stream); + + hipStream_t internal_before = ctx.get_queue().get(); + + ctx.wait_for(queue); + + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); + + ctx.finish_on(queue); + + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); +} + +TEST_CASE(test_fallback_event_path_produces_correct_results) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 5.0f); + std::vector ydata(n, 7.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + hipStream_t null_stream = nullptr; + auto results = + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {null_stream, true}); + + EXPECT(not results.empty()); + + EXPECT(hipDeviceSynchronize() == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + + auto* result_data = reinterpret_cast(host_result.data()); + bool all_correct = true; + for(unsigned int i = 0; i < n; ++i) + { + if(result_data[i] != 12.0f) + { + all_correct = false; + break; + } + } + EXPECT(all_correct); +} + +TEST_CASE(test_non_async_eval_uses_internal_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 4.0f); + std::vector ydata(n, 6.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}); + + EXPECT(not results.empty()); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + p.finish(); + auto host_result = migraphx::gpu::from_gpu(gout); + + auto* result_data = reinterpret_cast(host_result.data()); + bool all_correct = true; + for(unsigned int i = 0; i < n; ++i) + { + if(result_data[i] != 10.0f) + { + all_correct = false; + break; + } + } + EXPECT(all_correct); +} + +TEST_CASE(test_mixed_async_and_sync_evals) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto ext = create_external_stream(); + + // Async eval with external stream + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result = migraphx::gpu::from_gpu(gout); + auto* result_data = reinterpret_cast(host_result.data()); + for(unsigned int i = 0; i < n; ++i) + { + EXPECT(result_data[i] == 3.0f); + } + + // Sync eval with internal stream + auto gout2 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout2}}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + p.finish(); + + auto host_result2 = migraphx::gpu::from_gpu(gout2); + auto* result_data2 = reinterpret_cast(host_result2.data()); + for(unsigned int i = 0; i < n; ++i) + { + EXPECT(result_data2[i] == 3.0f); + } + + // Async eval again to confirm no stale state + auto gout3 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout3}}, {ext.get(), true}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result3 = migraphx::gpu::from_gpu(gout3); + auto* result_data3 = reinterpret_cast(host_result3.data()); + for(unsigned int i = 0; i < n; ++i) + { + EXPECT(result_data3[i] == 3.0f); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 15de4eec2fa85ca6d49e1e797b60e07f1a2db74e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 9 Apr 2026 18:43:39 -0500 Subject: [PATCH 03/11] Fix format --- .../gpu/include/migraphx/gpu/context.hpp | 5 ++--- test/gpu/external_stream.cpp | 21 ++++++++----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index c5126c578f0..7c10fc2d1bb 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -164,7 +164,7 @@ struct hip_device if(external_stream_ == nullptr) return; external_stream_ = nullptr; - auto internal = get(); + auto internal = get(); #if MIGRAPHX_USE_MIOPEN if(mihandle != nullptr) miopenSetStream(mihandle.get(), internal); @@ -401,8 +401,7 @@ struct context else { get_stream().record(finish_event.get()); - auto status = - hipStreamWaitEvent(queue.get(), finish_event.get(), 0); + auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); if(status != hipSuccess) MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); } diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index 4463694065a..444fa4d4a5e 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -124,8 +124,7 @@ TEST_CASE(test_external_stream_eval_uses_caller_stream) auto ext = create_external_stream(); - auto results = - p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); + auto results = p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); EXPECT(not results.empty()); @@ -161,8 +160,7 @@ TEST_CASE(test_external_stream_serialized_on_caller_stream) auto ext = create_external_stream(); - auto results = - p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); EXPECT(not results.empty()); @@ -258,7 +256,7 @@ TEST_CASE(test_external_stream_cleared_after_eval) auto ext = create_external_stream(); migraphx::context& ctx_ref = p.get_context(); - auto* gpu_ctx = ctx_ref.any_cast(); + auto* gpu_ctx = ctx_ref.any_cast(); EXPECT(gpu_ctx != nullptr); hipStream_t internal_stream = gpu_ctx->get_queue().get(); @@ -315,8 +313,7 @@ TEST_CASE(test_fallback_event_path_produces_correct_results) auto gout = migraphx::gpu::to_gpu(out); hipStream_t null_stream = nullptr; - auto results = - p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {null_stream, true}); + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {null_stream, true}); EXPECT(not results.empty()); @@ -362,7 +359,7 @@ TEST_CASE(test_non_async_eval_uses_internal_stream) auto gout = migraphx::gpu::to_gpu(out); migraphx::context& ctx_ref = p.get_context(); - auto* gpu_ctx = ctx_ref.any_cast(); + auto* gpu_ctx = ctx_ref.any_cast(); EXPECT(gpu_ctx != nullptr); auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}); @@ -412,7 +409,7 @@ TEST_CASE(test_mixed_async_and_sync_evals) auto gout = migraphx::gpu::to_gpu(out); migraphx::context& ctx_ref = p.get_context(); - auto* gpu_ctx = ctx_ref.any_cast(); + auto* gpu_ctx = ctx_ref.any_cast(); EXPECT(gpu_ctx != nullptr); auto ext = create_external_stream(); @@ -422,7 +419,7 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); - auto host_result = migraphx::gpu::from_gpu(gout); + auto host_result = migraphx::gpu::from_gpu(gout); auto* result_data = reinterpret_cast(host_result.data()); for(unsigned int i = 0; i < n; ++i) { @@ -435,7 +432,7 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); p.finish(); - auto host_result2 = migraphx::gpu::from_gpu(gout2); + auto host_result2 = migraphx::gpu::from_gpu(gout2); auto* result_data2 = reinterpret_cast(host_result2.data()); for(unsigned int i = 0; i < n; ++i) { @@ -448,7 +445,7 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); - auto host_result3 = migraphx::gpu::from_gpu(gout3); + auto host_result3 = migraphx::gpu::from_gpu(gout3); auto* result_data3 = reinterpret_cast(host_result3.data()); for(unsigned int i = 0; i < n; ++i) { From a0c057e2f8cb6aee66d31d4572495a4feb01381e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 9 Apr 2026 20:20:33 -0500 Subject: [PATCH 04/11] Cleanup --- .../gpu/include/migraphx/gpu/context.hpp | 32 +++---- test/gpu/external_stream.cpp | 83 ++++--------------- 2 files changed, 33 insertions(+), 82 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 7c10fc2d1bb..ad2e649f9b7 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -90,8 +90,8 @@ struct hip_device hipStream_t get() { - if(external_stream_ != nullptr) - return external_stream_; + if(external_stream != nullptr) + return external_stream; if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{})) { setup(); @@ -148,7 +148,7 @@ struct hip_device void set_external_stream(hipStream_t ext_stream) { - external_stream_ = ext_stream; + external_stream = ext_stream; #if MIGRAPHX_USE_MIOPEN if(mihandle != nullptr) miopenSetStream(mihandle.get(), ext_stream); @@ -161,10 +161,10 @@ struct hip_device void clear_external_stream() { - if(external_stream_ == nullptr) + if(external_stream == nullptr) return; - external_stream_ = nullptr; - auto internal = get(); + external_stream = nullptr; + auto *internal = get(); #if MIGRAPHX_USE_MIOPEN if(mihandle != nullptr) miopenSetStream(mihandle.get(), internal); @@ -175,14 +175,14 @@ struct hip_device #endif } - bool has_external_stream() const { return external_stream_ != nullptr; } + bool has_external_stream() const { return external_stream != nullptr; } void wait() const { - if(external_stream_ != nullptr) + if(external_stream != nullptr) { setup(); - auto status = hipStreamSynchronize(external_stream_); + auto status = hipStreamSynchronize(external_stream); if(status != hipSuccess) MIGRAPHX_THROW("Failed to wait: " + hip_error(status)); return; @@ -214,7 +214,7 @@ struct hip_device private: std::size_t id = 0; shared s = nullptr; - hipStream_t external_stream_ = nullptr; + hipStream_t external_stream = nullptr; #if MIGRAPHX_USE_MIOPEN shared mihandle = nullptr; #endif @@ -378,18 +378,18 @@ struct context void wait_for(any_ptr queue) { - auto ext = queue.get(); - if(ext != nullptr) - { - get_stream().set_external_stream(ext); - } - else + auto *ext = queue.get(); + if(ext == nullptr) { auto status = hipEventRecord(begin_event.get(), ext); if(status != hipSuccess) MIGRAPHX_THROW("Failed to record: " + hip_error(status)); get_stream().wait(begin_event.get()); } + else + { + get_stream().set_external_stream(ext); + } } void finish_on(any_ptr queue) diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index 444fa4d4a5e..7c12b786a69 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -47,6 +47,13 @@ static hip_stream_ptr create_external_stream() return hip_stream_ptr{stream}; } +static void verify_data(const migraphx::argument& result, const migraphx::shape& s, float expected) +{ + std::vector expected_data(s.elements(), expected); + auto expected_arg = migraphx::argument{s, expected_data.data()}; + EXPECT(result == expected_arg); +} + TEST_CASE(test_stream_override_get) { migraphx::gpu::context ctx{}; @@ -166,18 +173,7 @@ TEST_CASE(test_external_stream_serialized_on_caller_stream) EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); auto host_result = migraphx::gpu::from_gpu(gout); - - auto* result_data = reinterpret_cast(host_result.data()); - bool all_correct = true; - for(unsigned int i = 0; i < n; ++i) - { - if(result_data[i] != 3.0f) - { - all_correct = false; - break; - } - } - EXPECT(all_correct); + verify_data(host_result, out_shape, 3.0f); } TEST_CASE(test_multiple_async_evals_same_stream) @@ -214,18 +210,7 @@ TEST_CASE(test_multiple_async_evals_same_stream) EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); auto host_result = migraphx::gpu::from_gpu(gout); - - auto* result_data = reinterpret_cast(host_result.data()); - bool all_correct = true; - for(unsigned int i = 0; i < n; ++i) - { - if(result_data[i] != 3.0f) - { - all_correct = false; - break; - } - } - EXPECT(all_correct); + verify_data(host_result, out_shape, 3.0f); } TEST_CASE(test_external_stream_cleared_after_eval) @@ -319,18 +304,7 @@ TEST_CASE(test_fallback_event_path_produces_correct_results) EXPECT(hipDeviceSynchronize() == hipSuccess); auto host_result = migraphx::gpu::from_gpu(gout); - - auto* result_data = reinterpret_cast(host_result.data()); - bool all_correct = true; - for(unsigned int i = 0; i < n; ++i) - { - if(result_data[i] != 12.0f) - { - all_correct = false; - break; - } - } - EXPECT(all_correct); + verify_data(host_result, out_shape, 12.0f); } TEST_CASE(test_non_async_eval_uses_internal_stream) @@ -369,18 +343,7 @@ TEST_CASE(test_non_async_eval_uses_internal_stream) p.finish(); auto host_result = migraphx::gpu::from_gpu(gout); - - auto* result_data = reinterpret_cast(host_result.data()); - bool all_correct = true; - for(unsigned int i = 0; i < n; ++i) - { - if(result_data[i] != 10.0f) - { - all_correct = false; - break; - } - } - EXPECT(all_correct); + verify_data(host_result, out_shape, 10.0f); } TEST_CASE(test_mixed_async_and_sync_evals) @@ -419,12 +382,8 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); - auto host_result = migraphx::gpu::from_gpu(gout); - auto* result_data = reinterpret_cast(host_result.data()); - for(unsigned int i = 0; i < n; ++i) - { - EXPECT(result_data[i] == 3.0f); - } + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); // Sync eval with internal stream auto gout2 = migraphx::gpu::to_gpu(out); @@ -432,12 +391,8 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); p.finish(); - auto host_result2 = migraphx::gpu::from_gpu(gout2); - auto* result_data2 = reinterpret_cast(host_result2.data()); - for(unsigned int i = 0; i < n; ++i) - { - EXPECT(result_data2[i] == 3.0f); - } + auto host_result2 = migraphx::gpu::from_gpu(gout2); + verify_data(host_result2, out_shape, 3.0f); // Async eval again to confirm no stale state auto gout3 = migraphx::gpu::to_gpu(out); @@ -445,12 +400,8 @@ TEST_CASE(test_mixed_async_and_sync_evals) EXPECT(not gpu_ctx->get_stream().has_external_stream()); EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); - auto host_result3 = migraphx::gpu::from_gpu(gout3); - auto* result_data3 = reinterpret_cast(host_result3.data()); - for(unsigned int i = 0; i < n; ++i) - { - EXPECT(result_data3[i] == 3.0f); - } + auto host_result3 = migraphx::gpu::from_gpu(gout3); + verify_data(host_result3, out_shape, 3.0f); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From 563c78b05d5961857fa9ee9b1826860fd27566ed Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 11 Apr 2026 00:01:51 -0500 Subject: [PATCH 05/11] Update context to not rebind on the same stream --- .../gpu/include/migraphx/gpu/context.hpp | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index ad2e649f9b7..21a203c6fc3 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -148,6 +148,8 @@ struct hip_device void set_external_stream(hipStream_t ext_stream) { + if(external_stream == ext_stream) + return; external_stream = ext_stream; #if MIGRAPHX_USE_MIOPEN if(mihandle != nullptr) @@ -161,18 +163,9 @@ struct hip_device void clear_external_stream() { - if(external_stream == nullptr) - return; - external_stream = nullptr; - auto *internal = get(); -#if MIGRAPHX_USE_MIOPEN - if(mihandle != nullptr) - miopenSetStream(mihandle.get(), internal); -#endif -#if MIGRAPHX_USE_ROCBLAS - if(rbhandle != nullptr) - rocblas_set_stream(rbhandle.get(), internal); -#endif + // No-op: keep external stream bound to avoid repeated + // miopenSetStream/rocblas_set_stream rebinding on the next call. + // A different stream passed to set_external_stream will replace it. } bool has_external_stream() const { return external_stream != nullptr; } @@ -394,11 +387,7 @@ struct context void finish_on(any_ptr queue) { - if(get_stream().has_external_stream()) - { - get_stream().clear_external_stream(); - } - else + if(not get_stream().has_external_stream()) { get_stream().record(finish_event.get()); auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); From 25ea2466f883d977cb2158f09296f6f18cad8d23 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 16 Apr 2026 16:26:31 -0500 Subject: [PATCH 06/11] Remove noop clearStream call --- src/targets/gpu/include/migraphx/gpu/context.hpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 21a203c6fc3..7ec65d4d697 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -161,13 +161,6 @@ struct hip_device #endif } - void clear_external_stream() - { - // No-op: keep external stream bound to avoid repeated - // miopenSetStream/rocblas_set_stream rebinding on the next call. - // A different stream passed to set_external_stream will replace it. - } - bool has_external_stream() const { return external_stream != nullptr; } void wait() const From 054cc15a8e9a6b4c03032e01600477171ef26e01 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 16 Apr 2026 16:46:20 -0500 Subject: [PATCH 07/11] remove clear_stream from tests --- test/gpu/external_stream.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index 7c12b786a69..e1cc9380c75 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -69,7 +69,6 @@ TEST_CASE(test_stream_override_get) EXPECT(stream.get() != internal); EXPECT(stream.has_external_stream()); - stream.clear_external_stream(); EXPECT(stream.get() == internal); EXPECT(not stream.has_external_stream()); @@ -86,7 +85,6 @@ TEST_CASE(test_stream_override_get_queue) ctx.get_stream().set_external_stream(ext.get()); EXPECT(ctx.get_queue().get() == ext.get()); - ctx.get_stream().clear_external_stream(); EXPECT(ctx.get_queue().get() == original_queue); } From a164acb416b4f03b65564441ffc8315d49d9da86 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 16 Apr 2026 22:54:16 -0500 Subject: [PATCH 08/11] Update context and tests --- .../gpu/include/migraphx/gpu/context.hpp | 20 +++++++++++++------ test/gpu/external_stream.cpp | 9 +++++---- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 7ec65d4d697..2e9a9190cdd 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -364,7 +364,11 @@ struct context void wait_for(any_ptr queue) { - auto *ext = queue.get(); + if(get_stream().has_external_stream()) + return; + if(queue.unsafe_get() == nullptr) + return; + auto* ext = queue.get(); if(ext == nullptr) { auto status = hipEventRecord(begin_event.get(), ext); @@ -380,13 +384,17 @@ struct context void finish_on(any_ptr queue) { - if(not get_stream().has_external_stream()) + if(get_stream().has_external_stream()) { - get_stream().record(finish_event.get()); - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); + get_stream().set_external_stream(nullptr); + return; } + if(queue.unsafe_get() == nullptr) + return; + get_stream().record(finish_event.get()); + auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); } any_ptr get_queue() { return get_stream().get(); } diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index e1cc9380c75..4971e341c95 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -69,6 +69,7 @@ TEST_CASE(test_stream_override_get) EXPECT(stream.get() != internal); EXPECT(stream.has_external_stream()); + stream.set_external_stream(nullptr); EXPECT(stream.get() == internal); EXPECT(not stream.has_external_stream()); @@ -85,6 +86,8 @@ TEST_CASE(test_stream_override_get_queue) ctx.get_stream().set_external_stream(ext.get()); EXPECT(ctx.get_queue().get() == ext.get()); + ctx.get_stream().set_external_stream(nullptr); + EXPECT(ctx.get_queue().get() == original_queue); } @@ -254,8 +257,7 @@ TEST_CASE(test_wait_for_null_stream_uses_event_fallback) { migraphx::gpu::context ctx{}; - hipStream_t null_stream = nullptr; - migraphx::any_ptr queue(null_stream); + migraphx::any_ptr queue{}; hipStream_t internal_before = ctx.get_queue().get(); @@ -295,8 +297,7 @@ TEST_CASE(test_fallback_event_path_produces_correct_results) auto out = migraphx::fill_argument(out_shape, 0); auto gout = migraphx::gpu::to_gpu(out); - hipStream_t null_stream = nullptr; - auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {null_stream, true}); + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); EXPECT(not results.empty()); From 54ad81bd51766331d1ef85c1d3daff6a43cff202 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 16 Apr 2026 23:05:14 -0500 Subject: [PATCH 09/11] fix format --- test/gpu/external_stream.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index 4971e341c95..8c9f1087dc6 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -297,7 +297,8 @@ TEST_CASE(test_fallback_event_path_produces_correct_results) auto out = migraphx::fill_argument(out_shape, 0); auto gout = migraphx::gpu::to_gpu(out); - auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); + auto results = + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); EXPECT(not results.empty()); From bcc94e0d079fa632817e228c3f07647daad5786e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 29 Apr 2026 17:36:08 -0500 Subject: [PATCH 10/11] Change interface to use set_queue for async calls instead of wait_for, finish_on --- src/include/migraphx/context.hpp | 61 ++++--------------- src/program.cpp | 4 +- .../gpu/include/migraphx/gpu/context.hpp | 34 +---------- test/gpu/external_stream.cpp | 52 ++-------------- 4 files changed, 22 insertions(+), 129 deletions(-) diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index 29c939c07a5..8755561319f 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -68,12 +68,7 @@ any_ptr get_queue_context(T&) } template -void wait_for_context(T&, any_ptr) -{ -} - -template -void finish_on_context(T&, any_ptr) +void use_queue_context(T&, any_ptr) { } @@ -89,9 +84,7 @@ struct MIGRAPHX_EXPORT context // (optional) any_ptr get_queue(); // (optional) - void wait_for(any_ptr queue); - // (optional) - void finish_on(any_ptr queue); + void use_queue(any_ptr queue); // void finish() const; }; @@ -143,30 +136,17 @@ struct context } template - static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue) - -> decltype(private_detail_te_self.wait_for(queue)) - { - private_detail_te_self.wait_for(queue); - } - - template - static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue) - { - wait_for_context(private_detail_te_self, queue); - } - - template - static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue) - -> decltype(private_detail_te_self.finish_on(queue)) + static auto private_detail_te_default_use_queue(char, T&& private_detail_te_self, any_ptr queue) + -> decltype(private_detail_te_self.use_queue(queue)) { - private_detail_te_self.finish_on(queue); + private_detail_te_self.use_queue(queue); } template static void - private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue) + private_detail_te_default_use_queue(float, T&& private_detail_te_self, any_ptr queue) { - finish_on_context(private_detail_te_self, queue); + use_queue_context(private_detail_te_self, queue); } template @@ -192,9 +172,7 @@ struct context std::declval()), private_detail_te_default_get_queue(char(0), std::declval()), - private_detail_te_default_wait_for( - char(0), std::declval(), std::declval()), - private_detail_te_default_finish_on( + private_detail_te_default_use_queue( char(0), std::declval(), std::declval()), std::declval().finish(), void()); @@ -289,16 +267,10 @@ struct context return (*this).private_detail_te_get_handle().get_queue(); } - void wait_for(any_ptr queue) + void use_queue(any_ptr queue) { assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().wait_for(queue); - } - - void finish_on(any_ptr queue) - { - assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().finish_on(queue); + (*this).private_detail_te_get_handle().use_queue(queue); } void finish() const @@ -323,8 +295,7 @@ struct context virtual value to_value() const = 0; virtual void from_value(const value& v) = 0; virtual any_ptr get_queue() = 0; - virtual void wait_for(any_ptr queue) = 0; - virtual void finish_on(any_ptr queue) = 0; + virtual void use_queue(any_ptr queue) = 0; virtual void finish() const = 0; }; @@ -374,16 +345,10 @@ struct context return private_detail_te_default_get_queue(char(0), private_detail_te_value); } - void wait_for(any_ptr queue) override - { - - private_detail_te_default_wait_for(char(0), private_detail_te_value, queue); - } - - void finish_on(any_ptr queue) override + void use_queue(any_ptr queue) override { - private_detail_te_default_finish_on(char(0), private_detail_te_value, queue); + private_detail_te_default_use_queue(char(0), private_detail_te_value, queue); } void finish() const override { private_detail_te_value.finish(); } diff --git a/src/program.cpp b/src/program.cpp index c24be628f36..73e45c4320e 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -597,7 +597,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().wait_for(exec_env.queue); + contexts.front().use_queue(exec_env.queue); } if(trace_level > 0) @@ -662,7 +662,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().finish_on(exec_env.queue); + contexts.front().use_queue(any_ptr{}); } return ret; diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 2e9a9190cdd..a1e43353727 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -278,8 +278,6 @@ struct context }; context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1)) : current_device(std::make_shared(device_id, n)), - begin_event(create_event()), - finish_event(create_event()), pc(std::make_shared()) { } @@ -362,39 +360,14 @@ struct context this->current_device = std::make_shared(device, n_streams); } - void wait_for(any_ptr queue) + void use_queue(any_ptr queue) { - if(get_stream().has_external_stream()) - return; if(queue.unsafe_get() == nullptr) - return; - auto* ext = queue.get(); - if(ext == nullptr) - { - auto status = hipEventRecord(begin_event.get(), ext); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to record: " + hip_error(status)); - get_stream().wait(begin_event.get()); - } - else - { - get_stream().set_external_stream(ext); - } - } - - void finish_on(any_ptr queue) - { - if(get_stream().has_external_stream()) { get_stream().set_external_stream(nullptr); return; } - if(queue.unsafe_get() == nullptr) - return; - get_stream().record(finish_event.get()); - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); + get_stream().set_external_stream(queue.get()); } any_ptr get_queue() { return get_stream().get(); } @@ -434,9 +407,6 @@ struct context // for event perf timing shared start_event = nullptr; shared stop_event = nullptr; - // for stream synchronization - shared begin_event = nullptr; - shared finish_event = nullptr; std::shared_ptr pc = nullptr; }; diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp index 8c9f1087dc6..c2a7f4dec5e 100644 --- a/test/gpu/external_stream.cpp +++ b/test/gpu/external_stream.cpp @@ -91,7 +91,7 @@ TEST_CASE(test_stream_override_get_queue) EXPECT(ctx.get_queue().get() == original_queue); } -TEST_CASE(test_context_wait_for_sets_external_stream) +TEST_CASE(test_context_use_queue_sets_external_stream) { migraphx::gpu::context ctx{}; auto ext = create_external_stream(); @@ -99,11 +99,11 @@ TEST_CASE(test_context_wait_for_sets_external_stream) migraphx::any_ptr queue(ext.get()); hipStream_t before = ctx.get_queue().get(); - ctx.wait_for(queue); + ctx.use_queue(queue); EXPECT(ctx.get_queue().get() == ext.get()); EXPECT(ctx.get_queue().get() != before); - ctx.finish_on(queue); + ctx.use_queue(migraphx::any_ptr{}); EXPECT(ctx.get_queue().get() == before); } @@ -253,60 +253,18 @@ TEST_CASE(test_external_stream_cleared_after_eval) EXPECT(not gpu_ctx->get_stream().has_external_stream()); } -TEST_CASE(test_wait_for_null_stream_uses_event_fallback) +TEST_CASE(test_use_queue_null_clears_external_stream) { migraphx::gpu::context ctx{}; - migraphx::any_ptr queue{}; - hipStream_t internal_before = ctx.get_queue().get(); - ctx.wait_for(queue); - - EXPECT(not ctx.get_stream().has_external_stream()); - EXPECT(ctx.get_queue().get() == internal_before); - - ctx.finish_on(queue); + ctx.use_queue(migraphx::any_ptr{}); EXPECT(not ctx.get_stream().has_external_stream()); EXPECT(ctx.get_queue().get() == internal_before); } -TEST_CASE(test_fallback_event_path_produces_correct_results) -{ - const unsigned int n = 128; - - migraphx::program p; - auto* mm = p.get_main_module(); - - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); - auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); - mm->add_instruction(migraphx::make_op("add"), x, y); - - p.compile(migraphx::make_target("gpu")); - - std::vector xdata(n, 5.0f); - std::vector ydata(n, 7.0f); - auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; - auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; - - auto gx = migraphx::gpu::to_gpu(xarg); - auto gy = migraphx::gpu::to_gpu(yarg); - - migraphx::shape out_shape{migraphx::shape::float_type, {n}}; - auto out = migraphx::fill_argument(out_shape, 0); - auto gout = migraphx::gpu::to_gpu(out); - - auto results = - p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); - - EXPECT(not results.empty()); - - EXPECT(hipDeviceSynchronize() == hipSuccess); - auto host_result = migraphx::gpu::from_gpu(gout); - verify_data(host_result, out_shape, 12.0f); -} - TEST_CASE(test_non_async_eval_uses_internal_stream) { const unsigned int n = 128; From 844bf9cd02f4f99bd65e4368d26d62a679e8831a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 30 Apr 2026 09:29:13 -0500 Subject: [PATCH 11/11] Fix format --- src/targets/gpu/include/migraphx/gpu/context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index a1e43353727..0c2ff2cdda5 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -406,7 +406,7 @@ struct context bool measure_perf = false; // for event perf timing shared start_event = nullptr; - shared stop_event = nullptr; + shared stop_event = nullptr; std::shared_ptr pc = nullptr; };