diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index ff201f0a4db..d79f7de8435 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -137,6 +137,9 @@ static std::vector run_target(program p, auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first); m[x.first] = options.offload_copy ? arg : t.copy_to(arg); } + + p.eval(m); // run once for hip graph capture + auto gpu_out = p.eval(m); std::vector output(gpu_out.size()); std::cout << p << std::endl; diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index 29c939c07a5..a4a3fd4baaa 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -32,6 +32,7 @@ #include #include #include +#include #include namespace migraphx { @@ -77,6 +78,17 @@ void finish_on_context(T&, any_ptr) { } +template +void nop_context(T&, Ts&&...) +{ +} + +template +std::function()> get_capture_context(T&) +{ + return nullptr; +} + #ifdef TYPE_ERASED_DECLARATION // Type-erased interface for: @@ -92,6 +104,12 @@ struct MIGRAPHX_EXPORT context void wait_for(any_ptr queue); // (optional) void finish_on(any_ptr queue); + // (optional) + void start_capture(); + // (optional) + void end_capture(const std::vector& args); + // (optional) + std::function()> get_capture(); // void finish() const; }; @@ -169,6 +187,50 @@ struct context finish_on_context(private_detail_te_self, queue); } + template + static auto private_detail_te_default_start_capture(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.start_capture()) + { + private_detail_te_self.start_capture(); + } + + template + static void private_detail_te_default_start_capture(float, T&& private_detail_te_self) + { + nop_context(private_detail_te_self); + } + + template + static auto private_detail_te_default_end_capture(char, + T&& private_detail_te_self, + const std::vector& args) + -> decltype(private_detail_te_self.end_capture(args)) + { + private_detail_te_self.end_capture(args); + } + + template + static void private_detail_te_default_end_capture(float, + T&& private_detail_te_self, + const std::vector& args) + { + nop_context(private_detail_te_self, args); + } + + template + static auto private_detail_te_default_get_capture(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.get_capture()) + { + return private_detail_te_self.get_capture(); + } + + template + static std::function()> + private_detail_te_default_get_capture(float, T&& private_detail_te_self) + { + return get_capture_context(private_detail_te_self); + } + template struct private_te_unwrap_reference { @@ -196,6 +258,14 @@ struct context char(0), std::declval(), std::declval()), private_detail_te_default_finish_on( char(0), std::declval(), std::declval()), + private_detail_te_default_start_capture(char(0), + std::declval()), + private_detail_te_default_end_capture( + char(0), + std::declval(), + std::declval&>()), + private_detail_te_default_get_capture(char(0), + std::declval()), std::declval().finish(), void()); @@ -301,6 +371,24 @@ struct context (*this).private_detail_te_get_handle().finish_on(queue); } + void start_capture() + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().start_capture(); + } + + void end_capture(const std::vector& args) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().end_capture(args); + } + + std::function()> get_capture() + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_capture(); + } + void finish() const { assert((*this).private_detail_te_handle_mem_var); @@ -320,12 +408,15 @@ struct context virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; - virtual value to_value() const = 0; - virtual void from_value(const value& v) = 0; - virtual any_ptr get_queue() = 0; - virtual void wait_for(any_ptr queue) = 0; - virtual void finish_on(any_ptr queue) = 0; - virtual void finish() const = 0; + virtual value to_value() const = 0; + virtual void from_value(const value& v) = 0; + virtual any_ptr get_queue() = 0; + virtual void wait_for(any_ptr queue) = 0; + virtual void finish_on(any_ptr queue) = 0; + virtual void start_capture() = 0; + virtual void end_capture(const std::vector& args) = 0; + virtual std::function()> get_capture() = 0; + virtual void finish() const = 0; }; template @@ -386,6 +477,24 @@ struct context private_detail_te_default_finish_on(char(0), private_detail_te_value, queue); } + void start_capture() override + { + + private_detail_te_default_start_capture(char(0), private_detail_te_value); + } + + void end_capture(const std::vector& args) override + { + + private_detail_te_default_end_capture(char(0), private_detail_te_value, args); + } + + std::function()> get_capture() override + { + + return private_detail_te_default_get_capture(char(0), private_detail_te_value); + } + void finish() const override { private_detail_te_value.finish(); } PrivateDetailTypeErasedT private_detail_te_value; diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 8bc7310c2d2..3f985962a14 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -46,6 +46,7 @@ inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIP_GRAPH) struct program_impl; diff --git a/src/layout_convolution.cpp b/src/layout_convolution.cpp index 3bb0824f0d8..3f5e38e505a 100644 --- a/src/layout_convolution.cpp +++ b/src/layout_convolution.cpp @@ -119,9 +119,7 @@ void remove_layout(module& m) { if(ins->name() != "layout") continue; - auto perm = ins->get_operator().to_value()["permutation"].to_vector(); - auto iperm = find_permutation(ins->inputs().front()->get_shape()); - if(perm != iperm) + if(find_permutation(ins->get_shape()) != find_permutation(ins->inputs().front()->get_shape())) continue; m.replace_instruction(ins, ins->inputs().front()); } diff --git a/src/program.cpp b/src/program.cpp index 4ca08842e74..aa27a144328 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -590,6 +590,15 @@ std::vector program::eval(const parameter_map& params, { auto& contexts = this->impl->contexts; + if(contexts.size() == 1 and enabled(MIGRAPHX_ENABLE_HIP_GRAPH{})) + { + auto& ctx = contexts.front(); + auto run = ctx.get_capture(); + if(run != nullptr) + return run(); + ctx.start_capture(); + } + auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); std::vector ret; @@ -664,6 +673,12 @@ std::vector program::eval(const parameter_map& params, contexts.front().finish_on(exec_env.queue); } + if(contexts.size() == 1 and enabled(MIGRAPHX_ENABLE_HIP_GRAPH{})) + { + auto& ctx = contexts.front(); + ctx.end_capture(ret); + } + return ret; } diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 06ff9a25f1e..06963da5b9d 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -106,6 +106,23 @@ static bool concat_const_foldable(Iterator start, Iterator last, std::size_t iax }); } +MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) +{ + if(ins->name() != "convolution") + return false; + auto v = ins->get_operator().to_value(); + if(v.at("group").to() != 1) + return false; + if(not all_of(v.at("stride"), [](const value& x) { return x.to() == 1; })) + return false; + if(not all_of(v.at("padding"), [](const value& x) { return x.to() == 0; })) + return false; + if(not all_of(v.at("dilation"), [](const value& x) { return x.to() == 1; })) + return false; + auto w = ins->inputs().at(1)->get_shape(); + return std::all_of(w.lens().begin() + 2, w.lens().end(), [](std::size_t i) { return i == 1; }); +} + // conv(x, w) * a => conv(x, a * w) struct find_mul_conv { @@ -1093,6 +1110,61 @@ struct find_concat_conv } }; +// (x * w1) * w2 => x * (w1 * w2) +struct find_conv_conv_1x1 +{ + auto matcher() const + { + return conv_1x1( + match::arg(0)(match::used_once(), match::name("convolution").bind("input"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto input = r.instructions["input"]; + auto x_ins = input->inputs().front(); + auto wnxn = input->inputs()[1]; + auto w1x1 = ins->inputs()[1]; + + auto out_channels = w1x1->get_shape().lens()[0]; + auto mid_channels = w1x1->get_shape().lens()[1]; + auto in_channels_per_group = wnxn->get_shape().lens()[1]; + auto groups = x_ins->get_shape().lens()[1] / in_channels_per_group; + auto w_size = std::accumulate(wnxn->get_shape().lens().begin() + 2, + wnxn->get_shape().lens().end(), + std::size_t{1}, + std::multiplies<>{}); + + auto mw_dims = wnxn->get_shape().lens(); + mw_dims[1] *= groups; + + auto w1x1_reshaped = m.insert_instruction( + ins, + make_op("reshape", {{"dims", {out_channels, groups, mid_channels / groups}}}), + w1x1); + auto w1x1_grouped = m.insert_instruction( + ins, make_op("transpose", {{"permutation", {1, 0, 2}}}), w1x1_reshaped); + + auto wnxn_reshaped = m.insert_instruction( + ins, + make_op("reshape", + {{"dims", {groups, mid_channels / groups, in_channels_per_group * w_size}}}), + wnxn); + + auto mw = m.insert_instruction(ins, make_op("dot"), w1x1_grouped, wnxn_reshaped); + auto mw_transposed = + m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0, 2}}}), mw); + auto mw_reshaped = + m.insert_instruction(ins, make_op("reshape", {{"dims", mw_dims}}), mw_transposed); + + auto op = input->get_operator(); + op.from_value({{"group", 1}}); + auto conv = m.insert_instruction(ins, op, x_ins, mw_reshaped); + m.replace_instruction(ins, conv); + } +}; + static void move_instructions_back(module& m, instruction_ref pos, std::vector inss) { @@ -2395,6 +2467,7 @@ void simplify_algebra::apply(module& m) const find_zero_ops{}, find_dot_add{}, find_conv_add{}, + find_conv_conv_1x1{}, find_div_const{}, find_sub_const{}, find_rsqrt{}, diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp index 99da5306d52..91dbc7bb556 100644 --- a/src/targets/gpu/compile_hip_code_object.cpp +++ b/src/targets/gpu/compile_hip_code_object.cpp @@ -29,6 +29,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -151,7 +152,7 @@ compute_global_for(const context& ctx, std::size_t n, std::size_t over) ctx.get_current_device().get_max_workitems_per_cu(); return [n, over, max_global](std::size_t local) { std::size_t num_elements = n; - if(not hip_accept_non_uniform_wg()) + if(enabled(MIGRAPHX_ENABLE_HIP_GRAPH{}) or not hip_accept_non_uniform_wg()) { num_elements = (1 + (n - 1) / local) * local; } diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index e35b1274328..40c1af1203a 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -372,14 +372,15 @@ auto is_mlir_conv(mlir_mode mode) if(mode == mlir_mode::all) return true; // No winograd for group convolution - if(group > 1) - return true; - auto w = ins->inputs().at(1)->get_shape(); - if(w.lens().size() != 4) - return true; - if(w.lens()[2] != w.lens()[3]) - return true; - return (w.lens()[3] % 3) != 0; + return group == 1; + // if(group > 1) + // return true; + // auto w = ins->inputs().at(1)->get_shape(); + // if(w.lens().size() != 4) + // return true; + // if(w.lens()[2] != w.lens()[3]) + // return true; + // return (w.lens()[3] % 3) != 0; }); } diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index e29414d41f3..7e076eb9dbe 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -51,6 +51,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_NSTREAMS) using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); +using hip_graph_ptr = MIGRAPHX_MANAGE_PTR(hipGraph_t, hipGraphDestroy); +using hip_graph_exec_ptr = MIGRAPHX_MANAGE_PTR(hipGraphExec_t, hipGraphExecDestroy); struct hip_device { @@ -380,6 +382,38 @@ struct context pc->auto_save = true; } + void start_capture() + { + // hipStreamCaptureModeThreadLocal + // hipStreamCaptureModeGlobal + // hipStreamCaptureModeRelaxed + auto status = hipStreamBeginCapture(get_stream().get(), hipStreamCaptureModeGlobal); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed: hipStreamBeginCapture: " + hip_error(status)); + } + void end_capture(const std::vector& args) + { + hipGraph_t raw_graph = nullptr; + auto status = hipStreamEndCapture(get_stream().get(), &raw_graph); + auto graph = share(hip_graph_ptr{raw_graph}); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed: hipStreamEndCapture: " + hip_error(status)); + + // auto log = make_shared_array(1024); + hipGraphExec_t raw_graph_exec = nullptr; + status = hipGraphInstantiate(&raw_graph_exec, graph.get(), nullptr, nullptr, 0); + auto graph_exec = share(hip_graph_exec_ptr{raw_graph_exec}); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed: hipGraphInstantiate: " + hip_error(status)); + saved_graph = [this, graph, graph_exec, args] { + auto status2 = hipGraphLaunch(graph_exec.get(), get_stream().get()); + if(status2 != hipSuccess) + MIGRAPHX_THROW("Failed: hipGraphLaunch: " + hip_error(status2)); + return args; + }; + } + std::function()> get_capture() const { return saved_graph; } + private: // TODO: Make this a vector to support multiple devices std::shared_ptr current_device; @@ -393,6 +427,8 @@ struct context shared begin_event = nullptr; shared finish_event = nullptr; std::shared_ptr pc = nullptr; + + std::function()> saved_graph = nullptr; }; inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } diff --git a/tools/include/context.hpp b/tools/include/context.hpp index 50baa9e3a21..ac1e4972d22 100644 --- a/tools/include/context.hpp +++ b/tools/include/context.hpp @@ -32,6 +32,7 @@ #include #include #include +#include #include namespace migraphx { @@ -75,14 +76,35 @@ void wait_for_context(T&, any_ptr) template void finish_on_context(T&, any_ptr){} +template +void nop_context(T&, Ts&&...) +{ +} + +template +std::function()> get_capture_context(T&) +{ + return nullptr; +} + <% - interface('context', - virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), - virtual('from_value', v = 'const value&', default = 'from_value_context'), - virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), - virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'), - virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'), - virtual('finish', returns = 'void', const = True)) %> + interface( + 'context', + virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), + virtual('from_value', v = 'const value&', default = 'from_value_context'), + virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), + virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'), + virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'), + virtual('start_capture', returns = 'void', default = 'nop_context'), + virtual('end_capture', + args = 'const std::vector&', + returns = 'void', + default = 'nop_context'), + virtual('get_capture', + returns = 'std::function()>', + default = 'get_capture_context'), + virtual('finish', returns = 'void', const = True)) +%> inline void migraphx_to_value(value& v, const context& ctx) {