Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Full documentation for MIGraphX is available at
* Added a fuse_horizontal pass which batches independent cross embedding gather instructions (#4599).
* Added GPU JIT `Resize` kernel (#4553).
* Added environment variable `MIGRAPHX_SKIP_BENCHMARKING` which when enabled, skips tuning of MIGraphX and rocMLIR kernels (#4628).
* Added cross-compilation support via `MIGRAPHX_GPU_ARCH` environment variable, enabling compilation for a target GPU architecture without a physical device present (#4795).
* Added Cubic resize jit kernel (#4652).
* Added JIT compiler for `fill` operation (#4666).
* Added JIT compiler for `multinomial` operation (#4721).
Expand Down
22 changes: 22 additions & 0 deletions docs/reference/MIGraphX-dev-env-vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -742,4 +742,26 @@ Advanced settings

| Default: Benchmarking is not skipped.

* - | ``MIGRAPHX_GPU_ARCH``
| Enables cross-compilation mode by specifying a target GPU architecture without requiring a physical GPU.
| When set, kernel benchmarking and finalization are skipped. MIOpen, hipBLASLt, and CK operations are currently not supported in this mode.

- | Takes a valid GPU architecture string (e.g. ``gfx942``, ``gfx1100``).

| Default: Not set. A physical GPU is used.
Comment on lines +746 to +751

* - | ``MIGRAPHX_GPU_NUM_CU``
| Sets the number of compute units for cross-compilation mode. Only used when ``MIGRAPHX_GPU_ARCH`` is set.

- | Takes a positive integer.

| Default: ``120``

* - | ``MIGRAPHX_GPU_NUM_CHIPLETS``
| Sets the number of chiplets (XCCs) for cross-compilation mode. Only used when ``MIGRAPHX_GPU_ARCH`` is set.

- | Takes a positive integer.

| Default: ``1``


3 changes: 2 additions & 1 deletion src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -46,6 +46,7 @@ inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_ARCH)

struct program_impl;

Expand Down
7 changes: 5 additions & 2 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
}
}
}
this->finalize();
if(string_value_of(MIGRAPHX_GPU_ARCH{}).empty())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we should use env variable here. Instead the target or context should tell us if we can run finalize.

this->finalize();
Comment on lines +311 to +312
}

void program::compile(const target& t, compile_options options)
Expand All @@ -326,6 +327,7 @@ void program::compile(const target& t, compile_options options)
auto&& passes = t.get_passes(this->impl->contexts.front(), options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
bool cross_compiling = not string_value_of(MIGRAPHX_GPU_ARCH{}).empty();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we should use env variable here. Instead the target or context should tell us if we can run finalize.

// Validate and finalize
for(const auto& mod : reverse(mods))
{
Expand All @@ -342,7 +344,8 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->contexts);
if(not cross_compiling)
mod->finalize(this->impl->contexts);
Comment on lines 330 to +348
}
}

Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compile_pointwise.cpp
compiler.cpp
cross_compile_device.cpp
device_name.cpp
eliminate_data_type_for_gpu.cpp
fixed_pad.cpp
Expand Down
3 changes: 2 additions & 1 deletion src/targets/gpu/compile_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ struct compile_plan
if(solutions.empty())
MIGRAPHX_THROW("No solutions provided for " + preop.name() + " with " +
problem_string() + "\n\n" + print_modules());
if(enabled(MIGRAPHX_SKIP_BENCHMARKING{}) or solutions.size() == 1)
if(enabled(MIGRAPHX_SKIP_BENCHMARKING{}) or ctx->is_cross_compile() or
solutions.size() == 1)
{
ctx->get_problem_cache().insert(preop.name(), problem, solutions.front());
results.resize(1);
Expand Down
47 changes: 47 additions & 0 deletions src/targets/gpu/cross_compile_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/cross_compile_device.hpp>
#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

hipDeviceProp_t make_cross_compile_device_props(const std::string& arch_name, std::size_t cu_count)
{
hipDeviceProp_t props{};
auto n = std::min(arch_name.size(), sizeof(props.gcnArchName) - 1);
std::copy_n(arch_name.begin(), n, props.gcnArchName);
props.gcnArchName[n] = '\0';
// these are placeholders
props.warpSize = 64;
props.maxThreadsPerMultiProcessor = 2048;
props.maxThreadsPerBlock = 1024;
props.multiProcessorCount = cu_count;
return props;
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
14 changes: 7 additions & 7 deletions src/targets/gpu/device_name.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ bool gfx_has_mx_intrinsics(const context& ctx)
}

#if MIGRAPHX_USE_HIPBLASLT
static bool hipblaslt_supported_impl(const std::string& gfx_name)
{
return (gfx_name == "gfx90a" or (starts_with(gfx_name, "gfx94") and gfx_name >= "gfx942") or
(starts_with(gfx_name, "gfx95") and gfx_name >= "gfx950") or
starts_with(gfx_name, "gfx110") or starts_with(gfx_name, "gfx120"));
}

static bool gfx_default_rocblas_impl(const std::string& gfx_name)
{
return ((string_value_of(MIGRAPHX_SET_GEMM_PROVIDER{}) == "hipblaslt")
Expand All @@ -137,13 +144,6 @@ bool gfx_default_rocblas(const context& ctx)
}
#endif

static bool hipblaslt_supported_impl(const std::string& gfx_name)
{
return (gfx_name == "gfx90a" or (starts_with(gfx_name, "gfx94") and gfx_name >= "gfx942") or
(starts_with(gfx_name, "gfx95") and gfx_name >= "gfx950") or
starts_with(gfx_name, "gfx110") or starts_with(gfx_name, "gfx120"));
}

bool hipblaslt_supported()
{
#if !MIGRAPHX_USE_HIPBLASLT
Expand Down
31 changes: 18 additions & 13 deletions src/targets/gpu/eliminate_data_type_for_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <migraphx/gpu/eliminate_data_type_for_gpu.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/functional.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -68,43 +69,47 @@
return eliminate_data_type{unsupported_types, shape::float_type, device_functions};
}

static eliminate_data_type for_fp8fnuz()
template <class F>
static bool query_device(const context* ctx, F f)
{
if(ctx != nullptr)
return f(*ctx);
return f();
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this. Everything should just use the function that takes the context. The other overloads should be removed, in a later PR.


static eliminate_data_type for_fp8fnuz(const context* ctx)
{
std::set<std::string> unsupported_ops = {};

// disable dot & quant_dot if no hipblaslt
if(not hipblaslt_supported())
if(not query_device(ctx, MIGRAPHX_LIFT(hipblaslt_supported)))

Check warning on line 84 in src/targets/gpu/eliminate_data_type_for_gpu.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'and' instead of && [UseNamedLogicOperator]
{
unsupported_ops.insert("dot");
unsupported_ops.insert("quant_dot");
}

// MIOpen doesn't have support for fp8 pooling yet.
insert_miopen_pooling(unsupported_ops);

if(not gpu::gfx_has_fp8fnuz_intrinsics())
if(not query_device(ctx, MIGRAPHX_LIFT(gfx_has_fp8fnuz_intrinsics)))

Check warning on line 92 in src/targets/gpu/eliminate_data_type_for_gpu.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'and' instead of && [UseNamedLogicOperator]
{
insert_gemm_conv(unsupported_ops);
}
return eliminate_data_type{
{shape::fp8e4m3fnuz_type, shape::fp8e5m2fnuz_type}, shape::float_type, unsupported_ops};
}

static eliminate_data_type for_fp8ocp()
static eliminate_data_type for_fp8ocp(const context* ctx)
{
std::set<std::string> unsupported_ops = {};

// disable dot & quant_dot if no hipblaslt
if(not hipblaslt_supported())
if(not query_device(ctx, MIGRAPHX_LIFT(hipblaslt_supported)))

Check warning on line 104 in src/targets/gpu/eliminate_data_type_for_gpu.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'and' instead of && [UseNamedLogicOperator]
{
unsupported_ops.insert("dot");
unsupported_ops.insert("quant_dot");
}

// MIOpen doesn't have support for fp8 pooling yet.
insert_miopen_pooling(unsupported_ops);

if(not gpu::gfx_has_fp8ocp_intrinsics())
if(not query_device(ctx, MIGRAPHX_LIFT(gfx_has_fp8ocp_intrinsics)))

Check warning on line 112 in src/targets/gpu/eliminate_data_type_for_gpu.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'and' instead of && [UseNamedLogicOperator]
{
insert_gemm_conv(unsupported_ops);
}
Expand Down Expand Up @@ -133,7 +138,7 @@
{
std::set<shape::type_t> unsupported_floats;
// No BF-16 Support on Navi21
if(not gpu::gfx_has_bf16_intrinsics())
if(not query_device(ctx, MIGRAPHX_LIFT(gfx_has_bf16_intrinsics)))

Check warning on line 141 in src/targets/gpu/eliminate_data_type_for_gpu.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'and' instead of && [UseNamedLogicOperator]
{
unsupported_floats.insert(shape::bf16_type);
}
Expand All @@ -158,8 +163,8 @@

mpm.run_pass(for_device_functions());

mpm.run_pass(for_fp8fnuz());
mpm.run_pass(for_fp8ocp());
mpm.run_pass(for_fp8fnuz(ctx));
mpm.run_pass(for_fp8ocp(ctx));

mpm.run_pass(for_gemm_conv());
}
Expand Down
53 changes: 48 additions & 5 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/problem_cache.hpp>
#include <migraphx/gpu/hsa_chiplet.hpp>
#include <migraphx/gpu/cross_compile_device.hpp>
#include <unordered_map>
#include <memory>

Expand Down Expand Up @@ -69,6 +70,14 @@ struct hip_device
add_stream();
}

hip_device(const std::string& arch_name, std::size_t cu_count, std::size_t chiplets)
: cross_compile_mode(true),
chiplet_count_override(chiplets),
device_props(make_cross_compile_device_props(arch_name, cu_count))
{
add_stream();
}

struct stream
{
using hip_stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy);
Expand Down Expand Up @@ -211,7 +220,14 @@ struct hip_device

std::size_t get_cu_count() const { return device_props.multiProcessorCount; }

std::size_t get_chiplet_count() const { return get_hsa_chiplet_count(device_id); }
std::size_t get_chiplet_count() const
{
if(cross_compile_mode)
return chiplet_count_override;
return get_hsa_chiplet_count(device_id);
}

bool is_cross_compile() const { return cross_compile_mode; }

std::size_t get_max_workitems_per_cu() const
{
Expand All @@ -223,8 +239,10 @@ struct hip_device
std::size_t get_wavefront_size() const { return device_props.warpSize; }

private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
std::size_t device_id = 0;
std::size_t current_stream = 0;
bool cross_compile_mode = false;
std::size_t chiplet_count_override = 1;
std::vector<stream> streams;
hipDeviceProp_t device_props;

Expand Down Expand Up @@ -256,6 +274,12 @@ struct context
{
}

context(const std::string& arch_name, std::size_t cu_count, std::size_t chiplets)
: current_device(std::make_shared<hip_device>(arch_name, cu_count, chiplets)),
pc(std::make_shared<auto_save_problem_cache>())
{
}

hip_device& get_current_device()
{
assert(current_device != nullptr);
Expand All @@ -268,6 +292,11 @@ struct context
return *current_device;
}

bool is_cross_compile() const
{
return current_device != nullptr and current_device->is_cross_compile();
}

bool get_exhaustive_tune_flag() const { return exhaustive_tune; }

void set_exhaustive_tune_flag(bool t) { exhaustive_tune = t; }
Expand All @@ -292,7 +321,12 @@ struct context
hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); }

std::vector<argument> literals{};
void finish() const { get_stream().wait(); }
void finish() const
{
if(is_cross_compile())
MIGRAPHX_THROW("Cannot execute in cross-compilation mode");
get_stream().wait();
}

static hip_event_ptr create_event()
{
Expand Down Expand Up @@ -336,6 +370,8 @@ struct context

void wait_for(any_ptr queue)
{
if(is_cross_compile())
MIGRAPHX_THROW("Cannot execute in cross-compilation mode");
auto status = hipEventRecord(begin_event.get(), queue.get<hipStream_t>());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record: " + hip_error(status));
Expand All @@ -345,14 +381,21 @@ struct context

void finish_on(any_ptr queue)
{
if(is_cross_compile())
MIGRAPHX_THROW("Cannot execute in cross-compilation mode");
get_stream().record(finish_event.get());

auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status));
}

any_ptr get_queue() { return get_stream().get(); }
any_ptr get_queue()
{
if(is_cross_compile())
MIGRAPHX_THROW("Cannot execute in cross-compilation mode");
return get_stream().get();
}

std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
Expand Down
Loading
Loading