Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ struct hip_device

hipStream_t get()
{
if(external_stream != nullptr)
return external_stream;
if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{}))
{
setup();
Expand Down Expand Up @@ -144,8 +146,40 @@ struct hip_device
}
#endif

void set_external_stream(hipStream_t ext_stream)
{
if(external_stream == ext_stream)
return;
external_stream = ext_stream;
#if MIGRAPHX_USE_MIOPEN
if(mihandle != nullptr)
miopenSetStream(mihandle.get(), ext_stream);
#endif
#if MIGRAPHX_USE_ROCBLAS
if(rbhandle != nullptr)
rocblas_set_stream(rbhandle.get(), ext_stream);
#endif
}
Comment thread
TedThemistokleous marked this conversation as resolved.

void clear_external_stream()
{
// No-op: keep external stream bound to avoid repeated
Comment thread
TedThemistokleous marked this conversation as resolved.
Outdated
// miopenSetStream/rocblas_set_stream rebinding on the next call.
// A different stream passed to set_external_stream will replace it.
Comment thread
TedThemistokleous marked this conversation as resolved.
Outdated
}

bool has_external_stream() const { return external_stream != nullptr; }

void wait() const
{
if(external_stream != nullptr)
{
setup();
auto status = hipStreamSynchronize(external_stream);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait: " + hip_error(status));
return;
}
if(s == nullptr)
return;
setup();
Expand Down Expand Up @@ -173,6 +207,7 @@ struct hip_device
private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
hipStream_t external_stream = nullptr;
#if MIGRAPHX_USE_MIOPEN
shared<miopen_handle> mihandle = nullptr;
#endif
Expand Down Expand Up @@ -336,20 +371,29 @@ struct context

void wait_for(any_ptr queue)
{
auto status = hipEventRecord(begin_event.get(), queue.get<hipStream_t>());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record: " + hip_error(status));

get_stream().wait(begin_event.get());
auto *ext = queue.get<hipStream_t>();
if(ext == nullptr)
{
auto status = hipEventRecord(begin_event.get(), ext);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record: " + hip_error(status));
get_stream().wait(begin_event.get());
}
else
{
Comment thread
CharlieL7 marked this conversation as resolved.
Outdated
Comment thread
bdevorem marked this conversation as resolved.
get_stream().set_external_stream(ext);
}
}

void finish_on(any_ptr queue)
{
get_stream().record(finish_event.get());

auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status));
if(not get_stream().has_external_stream())
{
get_stream().record(finish_event.get());
auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status));
}
Comment thread
CharlieL7 marked this conversation as resolved.
Outdated
}

any_ptr get_queue() { return get_stream().get(); }
Expand Down
Loading
Loading