diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index 29c939c07a5..8755561319f 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -68,12 +68,7 @@ any_ptr get_queue_context(T&) } template -void wait_for_context(T&, any_ptr) -{ -} - -template -void finish_on_context(T&, any_ptr) +void use_queue_context(T&, any_ptr) { } @@ -89,9 +84,7 @@ struct MIGRAPHX_EXPORT context // (optional) any_ptr get_queue(); // (optional) - void wait_for(any_ptr queue); - // (optional) - void finish_on(any_ptr queue); + void use_queue(any_ptr queue); // void finish() const; }; @@ -143,30 +136,17 @@ struct context } template - static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue) - -> decltype(private_detail_te_self.wait_for(queue)) - { - private_detail_te_self.wait_for(queue); - } - - template - static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue) - { - wait_for_context(private_detail_te_self, queue); - } - - template - static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue) - -> decltype(private_detail_te_self.finish_on(queue)) + static auto private_detail_te_default_use_queue(char, T&& private_detail_te_self, any_ptr queue) + -> decltype(private_detail_te_self.use_queue(queue)) { - private_detail_te_self.finish_on(queue); + private_detail_te_self.use_queue(queue); } template static void - private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue) + private_detail_te_default_use_queue(float, T&& private_detail_te_self, any_ptr queue) { - finish_on_context(private_detail_te_self, queue); + use_queue_context(private_detail_te_self, queue); } template @@ -192,9 +172,7 @@ struct context std::declval()), private_detail_te_default_get_queue(char(0), std::declval()), - private_detail_te_default_wait_for( - char(0), std::declval(), std::declval()), - private_detail_te_default_finish_on( + private_detail_te_default_use_queue( char(0), std::declval(), std::declval()), std::declval().finish(), void()); @@ -289,16 +267,10 @@ struct context return (*this).private_detail_te_get_handle().get_queue(); } - void wait_for(any_ptr queue) + void use_queue(any_ptr queue) { assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().wait_for(queue); - } - - void finish_on(any_ptr queue) - { - assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().finish_on(queue); + (*this).private_detail_te_get_handle().use_queue(queue); } void finish() const @@ -323,8 +295,7 @@ struct context virtual value to_value() const = 0; virtual void from_value(const value& v) = 0; virtual any_ptr get_queue() = 0; - virtual void wait_for(any_ptr queue) = 0; - virtual void finish_on(any_ptr queue) = 0; + virtual void use_queue(any_ptr queue) = 0; virtual void finish() const = 0; }; @@ -374,16 +345,10 @@ struct context return private_detail_te_default_get_queue(char(0), private_detail_te_value); } - void wait_for(any_ptr queue) override - { - - private_detail_te_default_wait_for(char(0), private_detail_te_value, queue); - } - - void finish_on(any_ptr queue) override + void use_queue(any_ptr queue) override { - private_detail_te_default_finish_on(char(0), private_detail_te_value, queue); + private_detail_te_default_use_queue(char(0), private_detail_te_value, queue); } void finish() const override { private_detail_te_value.finish(); } diff --git a/src/program.cpp b/src/program.cpp index c24be628f36..73e45c4320e 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -597,7 +597,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().wait_for(exec_env.queue); + contexts.front().use_queue(exec_env.queue); } if(trace_level > 0) @@ -662,7 +662,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().finish_on(exec_env.queue); + contexts.front().use_queue(any_ptr{}); } return ret; diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index e29414d41f3..0c2ff2cdda5 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -90,6 +90,8 @@ struct hip_device hipStream_t get() { + if(external_stream != nullptr) + return external_stream; if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{})) { setup(); @@ -144,8 +146,33 @@ struct hip_device } #endif + void set_external_stream(hipStream_t ext_stream) + { + if(external_stream == ext_stream) + return; + external_stream = ext_stream; +#if MIGRAPHX_USE_MIOPEN + if(mihandle != nullptr) + miopenSetStream(mihandle.get(), ext_stream); +#endif +#if MIGRAPHX_USE_ROCBLAS + if(rbhandle != nullptr) + rocblas_set_stream(rbhandle.get(), ext_stream); +#endif + } + + bool has_external_stream() const { return external_stream != nullptr; } + void wait() const { + if(external_stream != nullptr) + { + setup(); + auto status = hipStreamSynchronize(external_stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait: " + hip_error(status)); + return; + } if(s == nullptr) return; setup(); @@ -173,6 +200,7 @@ struct hip_device private: std::size_t id = 0; shared s = nullptr; + hipStream_t external_stream = nullptr; #if MIGRAPHX_USE_MIOPEN shared mihandle = nullptr; #endif @@ -250,8 +278,6 @@ struct context }; context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1)) : current_device(std::make_shared(device_id, n)), - begin_event(create_event()), - finish_event(create_event()), pc(std::make_shared()) { } @@ -334,22 +360,14 @@ struct context this->current_device = std::make_shared(device, n_streams); } - void wait_for(any_ptr queue) + void use_queue(any_ptr queue) { - auto status = hipEventRecord(begin_event.get(), queue.get()); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to record: " + hip_error(status)); - - get_stream().wait(begin_event.get()); - } - - void finish_on(any_ptr queue) - { - get_stream().record(finish_event.get()); - - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); - if(status != hipSuccess) - MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); + if(queue.unsafe_get() == nullptr) + { + get_stream().set_external_stream(nullptr); + return; + } + get_stream().set_external_stream(queue.get()); } any_ptr get_queue() { return get_stream().get(); } @@ -388,10 +406,7 @@ struct context bool measure_perf = false; // for event perf timing shared start_event = nullptr; - shared stop_event = nullptr; - // for stream synchronization - shared begin_event = nullptr; - shared finish_event = nullptr; + shared stop_event = nullptr; std::shared_ptr pc = nullptr; }; diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp new file mode 100644 index 00000000000..c2a7f4dec5e --- /dev/null +++ b/test/gpu/external_stream.cpp @@ -0,0 +1,365 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +using hip_stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy); + +static hip_stream_ptr create_external_stream() +{ + hipStream_t stream; + auto status = hipStreamCreate(&stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to create stream"); + return hip_stream_ptr{stream}; +} + +static void verify_data(const migraphx::argument& result, const migraphx::shape& s, float expected) +{ + std::vector expected_data(s.elements(), expected); + auto expected_arg = migraphx::argument{s, expected_data.data()}; + EXPECT(result == expected_arg); +} + +TEST_CASE(test_stream_override_get) +{ + migraphx::gpu::context ctx{}; + auto& stream = ctx.get_stream(); + + hipStream_t internal = stream.get(); + EXPECT(internal != nullptr); + + auto ext = create_external_stream(); + stream.set_external_stream(ext.get()); + + EXPECT(stream.get() == ext.get()); + EXPECT(stream.get() != internal); + EXPECT(stream.has_external_stream()); + + stream.set_external_stream(nullptr); + + EXPECT(stream.get() == internal); + EXPECT(not stream.has_external_stream()); +} + +TEST_CASE(test_stream_override_get_queue) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + hipStream_t original_queue = ctx.get_queue().get(); + EXPECT(original_queue != nullptr); + + ctx.get_stream().set_external_stream(ext.get()); + EXPECT(ctx.get_queue().get() == ext.get()); + + ctx.get_stream().set_external_stream(nullptr); + + EXPECT(ctx.get_queue().get() == original_queue); +} + +TEST_CASE(test_context_use_queue_sets_external_stream) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + migraphx::any_ptr queue(ext.get()); + + hipStream_t before = ctx.get_queue().get(); + ctx.use_queue(queue); + EXPECT(ctx.get_queue().get() == ext.get()); + EXPECT(ctx.get_queue().get() != before); + + ctx.use_queue(migraphx::any_ptr{}); + EXPECT(ctx.get_queue().get() == before); +} + +TEST_CASE(test_external_stream_eval_uses_caller_stream) +{ + const unsigned int m = 64; + const unsigned int k = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {m, k}}); + auto y = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {k, m}})); + mm->add_instruction(migraphx::make_op("dot"), x, y); + + p.compile(migraphx::make_target("gpu")); + + migraphx::shape input_shape{migraphx::shape::float_type, {m, k}}; + migraphx::shape output_shape{migraphx::shape::float_type, {m, m}}; + auto input = migraphx::fill_argument(input_shape, 1); + auto ginput = migraphx::gpu::to_gpu(input); + + auto output = migraphx::fill_argument(output_shape, 0); + auto goutput = migraphx::gpu::to_gpu(output); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_output = migraphx::gpu::from_gpu(goutput); + EXPECT(host_output != output); +} + +TEST_CASE(test_external_stream_serialized_on_caller_stream) +{ + const unsigned int n = 256; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_multiple_async_evals_same_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + for(int iter = 0; iter < 5; ++iter) + { + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + } + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_external_stream_cleared_after_eval) +{ + const unsigned int n = 64; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + hipStream_t internal_stream = gpu_ctx->get_queue().get(); + + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(gpu_ctx->get_queue().get() == internal_stream); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); +} + +TEST_CASE(test_use_queue_null_clears_external_stream) +{ + migraphx::gpu::context ctx{}; + + hipStream_t internal_before = ctx.get_queue().get(); + + ctx.use_queue(migraphx::any_ptr{}); + + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); +} + +TEST_CASE(test_non_async_eval_uses_internal_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 4.0f); + std::vector ydata(n, 6.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}); + + EXPECT(not results.empty()); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + p.finish(); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 10.0f); +} + +TEST_CASE(test_mixed_async_and_sync_evals) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto ext = create_external_stream(); + + // Async eval with external stream + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); + + // Sync eval with internal stream + auto gout2 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout2}}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + p.finish(); + + auto host_result2 = migraphx::gpu::from_gpu(gout2); + verify_data(host_result2, out_shape, 3.0f); + + // Async eval again to confirm no stale state + auto gout3 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout3}}, {ext.get(), true}); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + auto host_result3 = migraphx::gpu::from_gpu(gout3); + verify_data(host_result3, out_shape, 3.0f); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/tools/include/context.hpp b/tools/include/context.hpp index 50baa9e3a21..4e02708c4de 100644 --- a/tools/include/context.hpp +++ b/tools/include/context.hpp @@ -68,21 +68,19 @@ any_ptr get_queue_context(T&) } template -void wait_for_context(T&, any_ptr) +void use_queue_context(T&, any_ptr) { } -template -void finish_on_context(T&, any_ptr){} - <% - interface('context', - virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), - virtual('from_value', v = 'const value&', default = 'from_value_context'), - virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), - virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'), - virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'), - virtual('finish', returns = 'void', const = True)) %> + interface( + 'context', + virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), + virtual('from_value', v = 'const value&', default = 'from_value_context'), + virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), + virtual('use_queue', queue = 'any_ptr', returns = 'void', default = 'use_queue_context'), + virtual('finish', returns = 'void', const = True)) +%> inline void migraphx_to_value(value& v, const context& ctx) {