Skip to content
61 changes: 13 additions & 48 deletions src/include/migraphx/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,7 @@ any_ptr get_queue_context(T&)
}

template <class T>
void wait_for_context(T&, any_ptr)
{
}

template <class T>
void finish_on_context(T&, any_ptr)
void use_queue_context(T&, any_ptr)
{
}

Expand All @@ -89,9 +84,7 @@ struct MIGRAPHX_EXPORT context
// (optional)
any_ptr get_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
void use_queue(any_ptr queue);
//
void finish() const;
};
Expand Down Expand Up @@ -143,30 +136,17 @@ struct context
}

template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
{
private_detail_te_self.wait_for(queue);
}

template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
}

template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(queue))
static auto private_detail_te_default_use_queue(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.use_queue(queue))
{
private_detail_te_self.finish_on(queue);
private_detail_te_self.use_queue(queue);
}

template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
private_detail_te_default_use_queue(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
use_queue_context(private_detail_te_self, queue);
}

template <class PrivateDetailTypeErasedT>
Expand All @@ -192,9 +172,7 @@ struct context
std::declval<const value&>()),
private_detail_te_default_get_queue(char(0),
std::declval<PrivateDetailTypeErasedT>()),
private_detail_te_default_wait_for(
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_finish_on(
private_detail_te_default_use_queue(
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
std::declval<PrivateDetailTypeErasedT>().finish(),
void());
Expand Down Expand Up @@ -289,16 +267,10 @@ struct context
return (*this).private_detail_te_get_handle().get_queue();
}

void wait_for(any_ptr queue)
void use_queue(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(queue);
}

void finish_on(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(queue);
(*this).private_detail_te_get_handle().use_queue(queue);
}

void finish() const
Expand All @@ -323,8 +295,7 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void use_queue(any_ptr queue) = 0;
virtual void finish() const = 0;
};

Expand Down Expand Up @@ -374,16 +345,10 @@ struct context
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}

void wait_for(any_ptr queue) override
{

private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
}

void finish_on(any_ptr queue) override
void use_queue(any_ptr queue) override
{

private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
private_detail_te_default_use_queue(char(0), private_detail_te_value, queue);
}

// Type-erased dispatch: forwards finish() to the wrapped concrete value.
void finish() const override
{
    private_detail_te_value.finish();
}
Expand Down
4 changes: 2 additions & 2 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ std::vector<argument> program::eval(const parameter_map& params,
if(exec_env.async)
{
assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
contexts.front().use_queue(exec_env.queue);
}

if(trace_level > 0)
Expand Down Expand Up @@ -662,7 +662,7 @@ std::vector<argument> program::eval(const parameter_map& params,
if(exec_env.async)
{
assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
contexts.front().use_queue(any_ptr{});
}

return ret;
Expand Down
55 changes: 35 additions & 20 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ struct hip_device

hipStream_t get()
{
if(external_stream != nullptr)
return external_stream;
if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{}))
{
setup();
Expand Down Expand Up @@ -144,8 +146,33 @@ struct hip_device
}
#endif

// Stores `ext_stream` as the external stream and forwards it to any library
// handles that already exist. A request for the stream that is already
// stored is a no-op.
//
// NOTE(review): the values returned by miopenSetStream / rocblas_set_stream
// are not inspected here.
// NOTE(review): when `ext_stream` is nullptr the handles are handed nullptr
// as their stream as well -- confirm that is the intended stream for them
// once the external stream is cleared.
void set_external_stream(hipStream_t ext_stream)
{
    // Only do work when the stored stream actually changes.
    if(external_stream != ext_stream)
    {
        external_stream = ext_stream;
#if MIGRAPHX_USE_MIOPEN
        // Re-point the MIOpen handle, if one has been created.
        if(mihandle != nullptr)
            miopenSetStream(mihandle.get(), external_stream);
#endif
#if MIGRAPHX_USE_ROCBLAS
        // Re-point the rocBLAS handle, if one has been created.
        if(rbhandle != nullptr)
            rocblas_set_stream(rbhandle.get(), external_stream);
#endif
    }
}
Comment thread
TedThemistokleous marked this conversation as resolved.

// True when a non-null external stream is currently stored.
bool has_external_stream() const
{
    const bool installed = (external_stream != nullptr);
    return installed;
}

void wait() const
{
if(external_stream != nullptr)
{
setup();
auto status = hipStreamSynchronize(external_stream);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait: " + hip_error(status));
return;
}
if(s == nullptr)
return;
setup();
Expand Down Expand Up @@ -173,6 +200,7 @@ struct hip_device
private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
hipStream_t external_stream = nullptr;
#if MIGRAPHX_USE_MIOPEN
shared<miopen_handle> mihandle = nullptr;
#endif
Expand Down Expand Up @@ -250,8 +278,6 @@ struct context
};
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
: current_device(std::make_shared<hip_device>(device_id, n)),
begin_event(create_event()),
finish_event(create_event()),
pc(std::make_shared<auto_save_problem_cache>())
{
}
Expand Down Expand Up @@ -334,22 +360,14 @@ struct context
this->current_device = std::make_shared<hip_device>(device, n_streams);
}

void wait_for(any_ptr queue)
void use_queue(any_ptr queue)
{
auto status = hipEventRecord(begin_event.get(), queue.get<hipStream_t>());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record: " + hip_error(status));

get_stream().wait(begin_event.get());
}

void finish_on(any_ptr queue)
{
get_stream().record(finish_event.get());

auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status));
if(queue.unsafe_get() == nullptr)
{
get_stream().set_external_stream(nullptr);
return;
}
get_stream().set_external_stream(queue.get<hipStream_t>());
}

// Returns the result of get_stream().get(), converted to any_ptr.
any_ptr get_queue()
{
    auto stream_handle = get_stream().get();
    return stream_handle;
}
Expand Down Expand Up @@ -389,9 +407,6 @@ struct context
// for event perf timing
shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr;
// for stream synchronization
shared<hip_event_ptr> begin_event = nullptr;
shared<hip_event_ptr> finish_event = nullptr;
std::shared_ptr<auto_save_problem_cache> pc = nullptr;
};

Expand Down
Loading
Loading