diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
index c9cd6e21b4eba..8302ed06c29f2 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <memory>
+
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/migraphx/gpu_data_transfer.h"
 #include "core/providers/migraphx/migraphx_call.h"
@@ -9,6 +11,23 @@
 
 namespace onnxruntime {
 
+namespace {
+
+struct StagingReturnInfo {
+  PinnedStagingPool* pool;
+  void* buffer;
+  size_t capacity;
+};
+
+void StagingReturnCallback(void* raw) {
+  std::unique_ptr<StagingReturnInfo> info(static_cast<StagingReturnInfo*>(raw));
+  info->pool->Release(info->buffer, info->capacity);
+}
+
+}  // namespace
+
+GPUDataTransfer::~GPUDataTransfer() = default;
+
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
   OrtDevice::DeviceType src_type = src_device.Type();
   OrtDevice::DeviceType dst_type = dst_device.Type();
@@ -38,28 +57,26 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
   const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU &&
                                   src_device.MemType() == OrtDevice::MemType::DEFAULT;
 
-  // for the sync version of memcpy, launch to hip default stream
+  // Use the EP's compute stream (non-blocking) instead of the default (null)
+  // stream to avoid the implicit cross-stream serialisation that the default
+  // stream imposes on all other streams.
   if (dst_is_gpu_default) {
     if (src_is_gpu_default) {
-      // Copy only if the two addresses are different.
       if (dst_data != src_data) {
-        HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
-        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
-        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
+        HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
+                                           hipMemcpyDeviceToDevice, stream_));
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
       }
     } else {
-      // copy from other CPU memory to GPU, this is blocking
-      HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
-      if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) {
-        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
-        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
-      }
+      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
+                                         hipMemcpyHostToDevice, stream_));
+      HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
     }
   } else if (src_is_gpu_default) {
-    // copying from GPU to CPU memory, this is blocking
-    HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
+                                       hipMemcpyDeviceToHost, stream_));
+    HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
   } else {
-    // copying between cpu memory
     ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
@@ -80,26 +97,59 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
   const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU &&
                                   src_device.MemType() == OrtDevice::MemType::DEFAULT;
 
+  auto hip_stream = static_cast<hipStream_t>(stream.GetHandle());
+
   if (dst_is_gpu_default) {
     if (src_is_gpu_default) {
-      // copying between GPU, this is non-blocking
-      HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice,
-                                    static_cast<hipStream_t>(stream.GetHandle())));
+      // D2D — always non-blocking
+      HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, hip_stream));
+    } else if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || bytes < kStagingThreshold) {
+      // Pinned source or small transfer — hipMemcpyAsync is already truly async for pinned memory;
+      // for tiny pageable transfers the staging overhead isn't worth it.
+      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, hip_stream));
     } else {
-      // If source are not pinned, the memory copy will be performed synchronously.
-      // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
-      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice,
-                                         static_cast<hipStream_t>(stream.GetHandle())));
+      // Pageable source above threshold — stage through a pinned buffer so the
+      // H2D DMA is truly async and the host thread returns immediately.
+      void* pinned = staging_pool_.Acquire(bytes);
+      if (pinned) {
+        std::memcpy(pinned, src_data, bytes);
+        auto err = hipMemcpyAsync(dst_data, pinned, bytes, hipMemcpyHostToDevice, hip_stream);
+        if (err != hipSuccess) {
+          staging_pool_.Release(pinned, bytes);
+          HIP_RETURN_IF_ERROR(err);
+        }
+        auto cb = std::make_unique<StagingReturnInfo>(StagingReturnInfo{&staging_pool_, pinned, bytes});
+        HIP_RETURN_IF_ERROR(hipLaunchHostFunc(hip_stream, StagingReturnCallback, cb.release()));
+      } else {
+        // hipHostMalloc failed — fall back to the (synchronous) direct path
+        HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, hip_stream));
+      }
     }
   } else if (src_is_gpu_default) {
-    // If dest are not pinned, the memory copy will be performed synchronously.
-    // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
-    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost,
-                                       static_cast<hipStream_t>(stream.GetHandle())));
+    if (dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || bytes < kStagingThreshold) {
+      // Pinned dest or small transfer — hipMemcpyAsync is already efficient.
+      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, hip_stream));
+    } else {
+      // Pageable dest above threshold — stage through pinned so the GPU→host
+      // DMA runs as one large transfer instead of the driver's internal chunking.
+      void* pinned = staging_pool_.Acquire(bytes);
+      if (pinned) {
+        auto err = hipMemcpyAsync(pinned, src_data, bytes, hipMemcpyDeviceToHost, hip_stream);
+        if (err != hipSuccess) {
+          staging_pool_.Release(pinned, bytes);
+          HIP_RETURN_IF_ERROR(err);
+        }
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(hip_stream));
+        std::memcpy(dst_data, pinned, bytes);
+        staging_pool_.Release(pinned, bytes);
+      } else {
+        HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, hip_stream));
+      }
+    }
   } else {
     if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) {
       // sync the stream first to make sure the data arrived
-      HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
+      HIP_RETURN_IF_ERROR(hipStreamSynchronize(hip_stream));
     }
     ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
index a4eb8efd2afea..3a4d5f2d948d3 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
@@ -3,19 +3,96 @@
 
 #pragma once
 
+#include <algorithm>
+#include <mutex>
+#include <vector>
+
 #include "core/providers/migraphx/migraphx_inc.h"
 #include "core/framework/data_transfer.h"
 
 namespace onnxruntime {
 
+// Thread-safe pool of hipHostMalloc'd staging buffers used to avoid the
+// silent synchronous fallback that hipMemcpyAsync performs when handed
+// pageable (non-pinned) host memory. Buffers are grown on demand and
+// recycled between copies via hipLaunchHostFunc callbacks.
+class PinnedStagingPool {
+  struct Buffer {
+    void* ptr;
+    size_t capacity;
+  };
+
+ public:
+  PinnedStagingPool() = default;
+
+  ~PinnedStagingPool() {
+    (void)hipDeviceSynchronize();
+    for (auto& b : pool_) {
+      (void)hipHostFree(b.ptr);
+    }
+  }
+
+  PinnedStagingPool(const PinnedStagingPool&) = delete;
+  PinnedStagingPool& operator=(const PinnedStagingPool&) = delete;
+
+  // Returns a pinned buffer with at least `bytes` capacity, or nullptr on
+  // allocation failure. Prefers the smallest adequate buffer already in
+  // the pool to minimise waste.
+  void* Acquire(size_t bytes) {
+    std::lock_guard<std::mutex> lock(mu_);
+    auto best = pool_.end();
+    for (auto it = pool_.begin(); it != pool_.end(); ++it) {
+      if (it->capacity >= bytes &&
+          (best == pool_.end() || it->capacity < best->capacity)) {
+        best = it;
+      }
+    }
+    if (best != pool_.end()) {
+      void* p = best->ptr;
+      pool_.erase(best);
+      return p;
+    }
+    void* p = nullptr;
+    if (hipHostMalloc(&p, bytes) != hipSuccess) return nullptr;
+    return p;
+  }
+
+  void Release(void* ptr, size_t capacity) {
+    std::lock_guard<std::mutex> lock(mu_);
+    if (pool_.size() >= kMaxPoolSize) {
+      auto smallest = std::min_element(
+          pool_.begin(), pool_.end(),
+          [](const Buffer& a, const Buffer& b) { return a.capacity < b.capacity; });
+      if (smallest != pool_.end() && smallest->capacity < capacity) {
+        (void)hipHostFree(smallest->ptr);
+        pool_.erase(smallest);
+      } else {
+        (void)hipHostFree(ptr);
+        return;
+      }
+    }
+    pool_.push_back({ptr, capacity});
+  }
+
+ private:
+  static constexpr size_t kMaxPoolSize = 8;
+  std::mutex mu_;
+  std::vector<Buffer> pool_;
+};
+
 class GPUDataTransfer : public IDataTransfer {
  public:
-  GPUDataTransfer() = default;
-  ~GPUDataTransfer() = default;
+  explicit GPUDataTransfer(hipStream_t stream = nullptr) : stream_(stream) {}
+  ~GPUDataTransfer();
 
   bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
 
   common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
 
   common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override;
+
+ private:
+  static constexpr size_t kStagingThreshold = 64 * 1024;  // 64 KiB
+  hipStream_t stream_;
+  mutable PinnedStagingPool staging_pool_;
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index 5be2a1a2336aa..5ced07d0a032c 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -173,6 +173,9 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
     external_stream_ = true;
     stream_ = static_cast<hipStream_t>(info.user_compute_stream);
     LOGS_DEFAULT(INFO) << "[MIGraphX EP] Using external user compute stream: " << stream_;
+  } else {
+    HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
+    LOGS_DEFAULT(INFO) << "[MIGraphX EP] Created non-blocking compute stream: " << stream_;
   }
 
   // Overwrite initialized values with values from environment variables.
@@ -300,7 +303,7 @@ std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators()
 }
 
 std::unique_ptr<onnxruntime::IDataTransfer> MIGraphXExecutionProvider::GetDataTransfer() const {
-  return std::make_unique<GPUDataTransfer>();
+  return std::make_unique<GPUDataTransfer>(stream_);
 }
 
 static bool IsTypeSupported(const NodeArg* node_arg) {
@@ -1393,15 +1396,24 @@ static void pad_input_tensor(const void* src_data, void* dst_data,
                              original_batch * bytes_per_batch,
                              hipMemcpyDeviceToDevice, stream));
 
-  // Pad with last batch element replicated
+  // Pad by replicating the last batch element using exponential doubling.
+  // Seed one copy, then double the filled region each iteration so the number
+  // of hipMemcpyAsync calls is O(log N) instead of O(N).
   if (original_batch > 0 && padded_batch > original_batch) {
     const char* last_batch = static_cast<const char*>(src_data) + (original_batch - 1) * bytes_per_batch;
     char* pad_start = static_cast<char*>(dst_data) + original_batch * bytes_per_batch;
-
-    for (std::size_t i = original_batch; i < padded_batch; ++i) {
-      HIP_CALL_THROW(hipMemcpyAsync(pad_start, last_batch, bytes_per_batch,
+    std::size_t slots_to_fill = padded_batch - original_batch;
+
+    HIP_CALL_THROW(hipMemcpyAsync(pad_start, last_batch, bytes_per_batch,
+                                  hipMemcpyDeviceToDevice, stream));
+    std::size_t filled = 1;
+    while (filled < slots_to_fill) {
+      std::size_t chunk = std::min(filled, slots_to_fill - filled);
+      HIP_CALL_THROW(hipMemcpyAsync(pad_start + filled * bytes_per_batch,
+                                    pad_start,
+                                    chunk * bytes_per_batch,
                                     hipMemcpyDeviceToDevice, stream));
-      pad_start += bytes_per_batch;
+      filled += chunk;
     }
   }
 }
@@ -1945,8 +1957,7 @@ static migraphx::program load_or_compile_model(
 // If original_batch_size is provided and < padded batch size, slices the output to remove padding
 static void run_migraphx_program(
     std::mutex* mgx_mu_ptr,
-    const OrtApi* api,
-    OrtKernelContext* context,
+    hipStream_t rocm_stream,
     Ort::KernelContext& ctx,
     migraphx::program& prog,
     migraphx::program_parameters& m,
@@ -1954,13 +1965,10 @@
     std::size_t original_batch_size = 0,
     std::size_t padded_batch_size = 0)
 {
-  void* rocm_stream;
-  Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream));
-
   std::optional<migraphx::arguments> prog_outputs;
   {  // Scoped lock for thread safety
     std::lock_guard<std::mutex> lock(*mgx_mu_ptr);
-    prog_outputs = prog.run_async(m, static_cast<hipStream_t>(rocm_stream));
+    prog_outputs = prog.run_async(m, rocm_stream);
   }
 
   bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 &&
@@ -1978,55 +1986,38 @@
     }
   }
 
-  // First, handle pre-allocated outputs (need slicing but were already bound)
-  // NOTE: This is a defensive path - pre-allocated outputs should NOT exist when slicing is needed.
+  // Defensive path for pre-allocated outputs that need slicing.
+  // Callers should use temp output buffers when slicing is needed so this path
+  // is not reached in normal operation; it exists only as a safety net.
   if (needs_slicing && !prog_output_indices_set.empty()) {
+    // Sync once to ensure MIGraphX has finished writing all pre-allocated outputs
+    // before any buffer may be reallocated by GetOutput below.
+    HIP_CALL_THROW(hipStreamSynchronize(rocm_stream));
+
     for (std::size_t i = 0; i < output_num; ++i) {
       if (prog_output_indices_set.count(i) > 0) {
-        // This output was pre-allocated with padded shape - need to copy sliced data
         auto gpu_res = (*prog_outputs)[i];
         migraphx::shape res_shape = gpu_res.get_shape();
         auto res_lens = res_shape.lengths();
-
-        // Create sliced shape for ORT output
+
         std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
         if (!ort_shape.empty() && static_cast<std::size_t>(ort_shape[0]) != original_batch_size) {
           ort_shape[0] = static_cast<int64_t>(original_batch_size);
-
-          // Calculate bytes to copy (sliced portion only)
+
           std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size;
           std::size_t bytes_to_copy = bytes_per_batch * original_batch_size;
-
-          // Allocate temp buffer for sliced data on GPU
-          void* temp_sliced_buffer = nullptr;
-          auto hip_status = hipMalloc(&temp_sliced_buffer, bytes_to_copy);
-          if (hip_status != hipSuccess) {
-            ORT_THROW("hipMalloc failed for sliced output buffer");
-          }
-
-          // Copy sliced data from MIGraphX output to temp buffer
-          HIP_CALL_THROW(hipMemcpyWithStream(temp_sliced_buffer,
-                                             gpu_res.data(),
-                                             bytes_to_copy,
-                                             hipMemcpyDeviceToDevice,
-                                             static_cast<hipStream_t>(rocm_stream)));
-
-          // Synchronize to ensure copy is complete before allocating ORT output
-          HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(rocm_stream)));
-
-          // Now allocate the ORT output tensor with the SLICED shape
+
+          const void* src_data = gpu_res.data();
           auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
           void* output_data = output_tensor.GetTensorMutableRawData();
-
-          // Copy from temp buffer to ORT output
-          HIP_CALL_THROW(hipMemcpyWithStream(output_data,
-                                             temp_sliced_buffer,
-                                             bytes_to_copy,
-                                             hipMemcpyDeviceToDevice,
-                                             static_cast<hipStream_t>(rocm_stream)));
-
-          // Free temporary buffer
-          (void)hipFree(temp_sliced_buffer);
+
+          if (output_data != src_data) {
+            HIP_CALL_THROW(hipMemcpyWithStream(output_data,
+                                               src_data,
+                                               bytes_to_copy,
+                                               hipMemcpyDeviceToDevice,
+                                               rocm_stream));
+          }
         }
       }
     }
@@ -2039,16 +2030,14 @@
       migraphx::shape res_shape = gpu_res.get_shape();
       auto res_lens = res_shape.lengths();
 
-      // Adjust output shape if slicing is needed
       std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
       if (needs_slicing && !ort_shape.empty()) {
-        ort_shape[0] = original_batch_size;  // Slice batch dimension
+        ort_shape[0] = original_batch_size;
       }
 
       auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
       void* output_data = output_tensor.GetTensorMutableRawData();
 
-      // Calculate bytes to copy (slice if needed)
       std::size_t bytes_to_copy = res_shape.bytes();
       if (needs_slicing && res_lens.size() > 0) {
         bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size;
       }
@@ -2058,7 +2047,7 @@
                                          gpu_res.data(),
                                          bytes_to_copy,
                                          hipMemcpyDeviceToDevice,
-                                         static_cast<hipStream_t>(rocm_stream)));
+                                         rocm_stream));
     }
   }
@@ -2323,8 +2312,7 @@ static std::vector<std::vector<int64_t>> build_input_shapes_in_cached_order(
 // Returns true if executed successfully, false if shapes don't match
 static bool execute_ultra_fast_path(
     MIGraphXFuncState* mgx_state,
-    const OrtApi* api,
-    OrtKernelContext* context,
+    hipStream_t rocm_stream,
     Ort::KernelContext& ctx)
 {
   if (!mgx_state->caches_valid || mgx_state->last_input_shapes_raw.empty()) {
@@ -2430,9 +2418,6 @@
   // Allocate and pad inputs if needed for dynamic batching
   bool using_padded_inputs = false;
   if (padded_batch_size > original_batch_size) {
-    void* rocm_stream_ptr;
-    Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream_ptr));
-    auto rocm_stream = static_cast<hipStream_t>(rocm_stream_ptr);
     using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size,
                                                   padded_batch_size, rocm_stream);
   }
@@ -2462,7 +2447,7 @@
   }
 
   // Run directly - minimal overhead path
-  run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m,
+  run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m,
                        mgx_state->cached_prog_output_indices,
                        original_batch_size, padded_batch_size);
 
@@ -2474,8 +2459,7 @@
 // Note: all_input_shapes is only consumed (moved) if the function returns true
 static bool execute_fast_path(
     MIGraphXFuncState* mgx_state,
-    const OrtApi* api,
-    OrtKernelContext* context,
+    hipStream_t rocm_stream,
     Ort::KernelContext& ctx,
     const std::string& current_hash,
     std::vector<std::vector<int64_t>>& all_input_shapes)
@@ -2621,9 +2605,6 @@
   // Allocate and pad inputs if needed for dynamic batching
   bool using_padded_inputs = false;
   if (padded_batch_size > original_batch_size) {
-    void* rocm_stream_ptr;
-    Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream_ptr));
-    auto rocm_stream = static_cast<hipStream_t>(rocm_stream_ptr);
     using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size,
                                                   padded_batch_size, rocm_stream);
   }
@@ -2663,7 +2644,7 @@
     }
   }
 
-  run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog,
+  run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog,
                        mgx_state->cached_prog_params.value(),
                        mgx_state->cached_prog_output_indices,
                        original_batch_size, padded_batch_size);
@@ -2910,8 +2891,7 @@ static void compile_dynamic_batch_models(
 // Standard path: Shape checking, potential recompilation, and execution
 static void execute_standard_path(
     MIGraphXFuncState* mgx_state,
-    const OrtApi* api,
-    OrtKernelContext* context,
+    hipStream_t rocm_stream,
     Ort::KernelContext& ctx,
     const std::string& current_hash,
     std::vector<std::vector<int64_t>>&& all_input_shapes,
@@ -3013,16 +2993,15 @@
    }
 
     // Allocate and pad inputs for dynamic batching
-    void* rocm_stream_ptr;
-    Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream_ptr));
-    auto rocm_stream = static_cast<hipStream_t>(rocm_stream_ptr);
     bool using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size,
                                                        padded_batch_size, rocm_stream);
 
-    // Bind inputs and outputs with temporary output buffers (for slicing)
-    std::vector<void*> temp_output_buffers;
+    // Get or reuse cached temp output buffers (avoids hipMalloc/hipFree per run)
+    auto temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers(
+        mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size);
+
     auto [m, prog_output_indices] = handle_program_input_outputs(
-        param_shapes, output_shapes, map_input_name_index, ctx, true, &temp_output_buffers);
+        param_shapes, output_shapes, map_input_name_index, ctx, true, &temp_output_buffer_ptrs);
 
     mgx_state->cached_prog_params = m;
     mgx_state->cached_prog_output_indices = prog_output_indices;
@@ -3039,19 +3018,11 @@
       }
     }
 
-    // Run with slicing enabled
-    run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m,
+    run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m,
                          prog_output_indices, original_batch_size, padded_batch_size);
 
-    // Free temporary output buffers
-    for (void* buf : temp_output_buffers) {
-      if (buf != nullptr) {
-        (void)hipFree(buf);
-      }
-    }
-
-    // NOTE: Padded buffers are kept for reuse - they will be freed when batch size changes
-    // or when the state is destroyed
+    // Temp output buffers are cached on mgx_state for reuse across runs.
+    // They are freed when the batch size changes or when the state is destroyed.
 
     return;
   } else {
@@ -3075,8 +3046,8 @@
     mgx_state->last_input_shape_hash = current_hash;
     mgx_state->caches_valid = true;
 
-    run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m, prog_output_indices,
-                         0, 0);  // Pass 0,0 for batch sizes to indicate no slicing needed
+    run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices,
+                         0, 0);
 
     return;
   }
@@ -3126,8 +3097,10 @@
   mgx_state->last_input_shape_hash = current_hash;
   mgx_state->caches_valid = true;
 
-  run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m, prog_output_indices,
-                       original_batch_size, padded_batch_size);
+  // The program at this point was compiled (or already matched) for the exact
+  // runtime input shapes, so its outputs already have the original batch
+  // dimension. Pass 0,0 to avoid incorrect slicing arithmetic.
+  run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices);
 }
 
 // Build MIGraphX ONNX options with default shapes for symbolic dimensions
@@ -4077,20 +4050,31 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
   };
 
   compute_info.release_state_func = [](FunctionState state) {
-    if (state)
-      delete static_cast<MIGraphXFuncState*>(state);
+    if (state) {
+      auto* s = static_cast<MIGraphXFuncState*>(state);
+      for (auto& buf : s->padded_input_buffers) {
+        if (buf.data) (void)hipFree(buf.data);
+      }
+      for (auto& buf : s->temp_output_buffers) {
+        if (buf.data) (void)hipFree(buf.data);
+      }
+      delete s;
+    }
   };
 
-  compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+  compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* /*api*/, OrtKernelContext* context) {
    Ort::KernelContext ctx(context);
    MIGraphXFuncState* mgx_state = reinterpret_cast<MIGraphXFuncState*>(state);
    const auto& map_input_name_index = mgx_state->input_name_indexes;
 
+    // stream_ is always valid: either the user's external stream or an
+    // EP-owned hipStreamNonBlocking created in the constructor.
+
     // ═══════════════════════════════════════════════════════════════════════
     // ULTRA-FAST PATH: Shapes unchanged from last run
     // ═══════════════════════════════════════════════════════════════════════
-    if (execute_ultra_fast_path(mgx_state, api, context, ctx)) {
+    if (execute_ultra_fast_path(mgx_state, stream_, ctx)) {
       return Status::OK();
     }
 
@@ -4108,14 +4092,14 @@
     // ═══════════════════════════════════════════════════════════════════════
     // FAST PATH: Check cached programs for this shape hash
     // ═══════════════════════════════════════════════════════════════════════
-    if (execute_fast_path(mgx_state, api, context, ctx, current_hash, all_input_shapes)) {
+    if (execute_fast_path(mgx_state, stream_, ctx, current_hash, all_input_shapes)) {
       return Status::OK();
     }
 
     // ═══════════════════════════════════════════════════════════════════════
     // STANDARD PATH: Shape checking and potential recompilation
     // ═══════════════════════════════════════════════════════════════════════
-    execute_standard_path(mgx_state, api, context, ctx, current_hash, std::move(all_input_shapes),
+    execute_standard_path(mgx_state, stream_, ctx, current_hash, std::move(all_input_shapes),
                           model_cache_path_, model_path_, mxr_filename_prefix);
 
     return Status::OK();
@@ -4129,7 +4113,7 @@
 void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry,
                                                        AllocatorMap& allocators) const {
   auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)];
-  RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, external_stream_);
+  RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, /*use_existing_stream=*/true);
 }
 
 OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
index 75c3aa3429a02..24aeee586263d 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -162,7 +162,11 @@ struct MIGraphXFuncState {
 class MIGraphXExecutionProvider : public IExecutionProvider {
  public:
   explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
-  ~MIGraphXExecutionProvider() override = default;
+  ~MIGraphXExecutionProvider() override {
+    if (!external_stream_ && stream_) {
+      (void)hipStreamDestroy(stream_);
+    }
+  }
 
   Status Sync() const override;
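
Reviewer note: the exponential-doubling pad in pad_input_tensor is easy to sanity-check on the host. The standalone sketch below is illustrative only (not part of the patch); it mirrors the loop with plain memcpy and asserts the postcondition. The invariant that each iteration reads only from the already-filled prefix is the same property that makes the in-order hipMemcpyAsync sequence on a single stream correct, and the total copy count is 1 + ceil(log2(slots)) instead of slots.

// Host-side check of the doubling fill (illustrative, not from the patch).
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>

int main() {
  const std::size_t bytes_per_batch = 4;
  const std::size_t original_batch = 3, padded_batch = 13;
  std::vector<char> buf(padded_batch * bytes_per_batch, 0);
  for (std::size_t i = 0; i < original_batch; ++i)
    std::memset(buf.data() + i * bytes_per_batch, 'a' + static_cast<int>(i), bytes_per_batch);

  char* pad_start = buf.data() + original_batch * bytes_per_batch;
  const char* last_batch = buf.data() + (original_batch - 1) * bytes_per_batch;
  std::size_t slots_to_fill = padded_batch - original_batch;

  std::memcpy(pad_start, last_batch, bytes_per_batch);  // seed one replica
  std::size_t filled = 1;
  while (filled < slots_to_fill) {
    // chunk <= filled, so the source range is entirely inside the filled prefix
    std::size_t chunk = std::min(filled, slots_to_fill - filled);
    std::memcpy(pad_start + filled * bytes_per_batch, pad_start, chunk * bytes_per_batch);
    filled += chunk;
  }

  // Every pad slot must now equal the last real batch element.
  for (std::size_t i = original_batch; i < padded_batch; ++i)
    assert(std::memcmp(buf.data() + i * bytes_per_batch, last_batch, bytes_per_batch) == 0);
  return 0;
}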
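Reviewer note: a minimal sketch of the pinned-staging H2D pattern that CopyTensorAsync now uses, assuming a HIP toolchain; the StagingFreeList, ReturnCtx, and StagedH2D names are illustrative and not from the patch. A pageable source is first memcpy'd into pinned memory so that hipMemcpyAsync can DMA without the driver's silent synchronous fallback, and hipLaunchHostFunc recycles the buffer only after the DMA has completed on the stream. Host callbacks must not themselves call HIP APIs, which is why the callback only re-pools the pointer rather than freeing it; this is also the rationale for recycling via a pool in the patch.

// Pinned-staging H2D sketch (illustrative, not from the patch).
#include <hip/hip_runtime.h>
#include <cstring>
#include <mutex>
#include <utility>
#include <vector>

struct StagingFreeList {
  std::mutex mu;
  std::vector<std::pair<void*, size_t>> free_list;
  void Put(void* p, size_t n) {
    std::lock_guard<std::mutex> lock(mu);
    free_list.emplace_back(p, n);
  }
};

struct ReturnCtx { StagingFreeList* pool; void* pinned; size_t bytes; };

static void ReturnStaging(void* raw) {
  auto* ctx = static_cast<ReturnCtx*>(raw);
  ctx->pool->Put(ctx->pinned, ctx->bytes);  // no HIP calls allowed in a host callback
  delete ctx;
}

hipError_t StagedH2D(StagingFreeList& pool, void* dst, const void* src,
                     size_t bytes, hipStream_t stream) {
  void* pinned = nullptr;
  if (hipHostMalloc(&pinned, bytes) != hipSuccess) {
    // Degraded path: the copy still happens, just synchronously for a pageable src.
    return hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
  }
  std::memcpy(pinned, src, bytes);  // host-side stage into pinned memory
  hipError_t err = hipMemcpyAsync(dst, pinned, bytes, hipMemcpyHostToDevice, stream);
  if (err != hipSuccess) {
    (void)hipHostFree(pinned);  // nothing was enqueued; safe to free immediately
    return err;
  }
  // Recycle the buffer once the stream reaches this point, i.e. after the DMA.
  return hipLaunchHostFunc(stream, ReturnStaging, new ReturnCtx{&pool, pinned, bytes});
}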