102 changes: 76 additions & 26 deletions onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstring>
#include <memory>

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/gpu_data_transfer.h"
#include "core/providers/migraphx/migraphx_call.h"
@@ -9,6 +11,23 @@

namespace onnxruntime {

namespace {

struct StagingReturnInfo {
PinnedStagingPool* pool;
void* buffer;
size_t capacity;
};

// Runs on a HIP runtime thread once all work queued ahead of it on the
// stream (including the staged copy) has completed. Host functions must not
// make HIP API calls, so Release only parks the buffer for reuse.
void StagingReturnCallback(void* raw) {
std::unique_ptr<StagingReturnInfo> info(static_cast<StagingReturnInfo*>(raw));
info->pool->Release(info->buffer, info->capacity);
}

} // namespace

GPUDataTransfer::~GPUDataTransfer() = default;

bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
OrtDevice::DeviceType src_type = src_device.Type();
OrtDevice::DeviceType dst_type = dst_device.Type();
@@ -38,28 +57,26 @@
const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU &&
src_device.MemType() == OrtDevice::MemType::DEFAULT;

// for the sync version of memcpy, launch to hip default stream
// Use the EP's compute stream (non-blocking) instead of the default (null)
// stream to avoid the implicit cross-stream serialization that the default
// stream imposes on all other streams.
if (dst_is_gpu_default) {
if (src_is_gpu_default) {
// Copy only if the two addresses are different.
if (dst_data != src_data) {
HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
// Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
hipMemcpyDeviceToDevice, stream_));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
}
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) {
// Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
}
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
hipMemcpyHostToDevice, stream_));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
}
} else if (src_is_gpu_default) {
// copying from GPU to CPU memory, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes,
hipMemcpyDeviceToHost, stream_));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
} else {
// copying between cpu memory
ORT_ENFORCE(dst_data != src_data);
memcpy(dst_data, src_data, bytes);
}
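(Aside for reviewers: "non-blocking" above refers to how the EP's compute stream is created. A stream created with the hipStreamNonBlocking flag opts out of the implicit synchronization that the legacy default (null) stream imposes. The sketch below only illustrates that flag; this diff does not show how stream_ is actually produced, so treat it as an assumption:

    hipStream_t s = nullptr;
    HIP_CALL_THROW(hipStreamCreateWithFlags(&s, hipStreamNonBlocking));
)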
@@ -80,26 +97,59 @@
const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU &&
src_device.MemType() == OrtDevice::MemType::DEFAULT;

auto hip_stream = static_cast<hipStream_t>(stream.GetHandle());

if (dst_is_gpu_default) {
if (src_is_gpu_default) {
// copying between GPU, this is non-blocking
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(stream.GetHandle())));
// D2D — always non-blocking
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, hip_stream));
} else if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || bytes < kStagingThreshold) {
// Pinned source or small transfer — hipMemcpyAsync is already truly async for pinned memory;
// for tiny pageable transfers the staging overhead isn't worth it.
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, hip_stream));
} else {
// If source are not pinned, the memory copy will be performed synchronously.
// For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice,
static_cast<hipStream_t>(stream.GetHandle())));
// Pageable source above threshold — stage through a pinned buffer so the
// H2D DMA is truly async and the host thread returns immediately.
void* pinned = staging_pool_.Acquire(bytes);
if (pinned) {
std::memcpy(pinned, src_data, bytes);
auto err = hipMemcpyAsync(dst_data, pinned, bytes, hipMemcpyHostToDevice, hip_stream);
if (err != hipSuccess) {
staging_pool_.Release(pinned, bytes);
HIP_RETURN_IF_ERROR(err);
}
auto cb = std::make_unique<StagingReturnInfo>(StagingReturnInfo{&staging_pool_, pinned, bytes});
auto launch_err = hipLaunchHostFunc(hip_stream, StagingReturnCallback, cb.get());
if (launch_err != hipSuccess) {
// The copy may still be in flight; wait for it before recycling the buffer.
(void)hipStreamSynchronize(hip_stream);
staging_pool_.Release(pinned, bytes);
HIP_RETURN_IF_ERROR(launch_err);
}
cb.release();  // ownership now belongs to StagingReturnCallback
} else {
// hipHostMalloc failed — fall back to the (synchronous) direct path
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, hip_stream));
}
}
} else if (src_is_gpu_default) {
// If dest are not pinned, the memory copy will be performed synchronously.
// For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost,
static_cast<hipStream_t>(stream.GetHandle())));
if (dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || bytes < kStagingThreshold) {
// Pinned dest or small transfer — hipMemcpyAsync is already efficient.
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, hip_stream));
} else {
// Pageable dest above threshold — stage through pinned so the GPU→host
// DMA runs as one large transfer instead of the driver's internal chunking.
void* pinned = staging_pool_.Acquire(bytes);
if (pinned) {
auto err = hipMemcpyAsync(pinned, src_data, bytes, hipMemcpyDeviceToHost, hip_stream);
if (err == hipSuccess) {
err = hipStreamSynchronize(hip_stream);
}
if (err != hipSuccess) {
staging_pool_.Release(pinned, bytes);
HIP_RETURN_IF_ERROR(err);
}
std::memcpy(dst_data, pinned, bytes);
staging_pool_.Release(pinned, bytes);
} else {
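// Pinned staging unavailable; fall back to the direct copy, which is
// synchronous for a pageable destination.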
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, hip_stream));
}
}
} else {
if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) {
// sync the stream first to make sure the data arrived
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(hip_stream));
}
ORT_ENFORCE(dst_data != src_data);
memcpy(dst_data, src_data, bytes);
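To make the new host-to-device path easier to review, here is the same pattern reduced to a standalone sketch. StagedH2D, Park, and g_parked are illustrative names, not part of this PR; the real code uses PinnedStagingPool, StagingReturnCallback, and HIP_RETURN_IF_ERROR instead:

    #include <cstring>
    #include <mutex>
    #include <vector>
    #include <hip/hip_runtime.h>

    static std::mutex g_mu;
    static std::vector<void*> g_parked;  // a real pool would reuse these buffers

    // Runs on a HIP runtime thread; HIP API calls are not allowed here, so the
    // buffer is parked rather than hipHostFree'd.
    static void Park(void* p) {
      std::lock_guard<std::mutex> lk(g_mu);
      g_parked.push_back(p);
    }

    // Pinned-staging host-to-device copy: memcpy into pinned memory, start a
    // truly asynchronous DMA, reclaim the staging buffer via a stream callback.
    hipError_t StagedH2D(void* dst_dev, const void* src_pageable, size_t bytes,
                         hipStream_t stream) {
      void* pinned = nullptr;
      hipError_t err = hipHostMalloc(&pinned, bytes);
      if (err != hipSuccess) return err;
      std::memcpy(pinned, src_pageable, bytes);  // host copy into pinned staging
      err = hipMemcpyAsync(dst_dev, pinned, bytes, hipMemcpyHostToDevice, stream);
      if (err != hipSuccess) {
        (void)hipHostFree(pinned);
        return err;
      }
      return hipLaunchHostFunc(stream, Park, pinned);
    }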
81 changes: 79 additions & 2 deletions onnxruntime/core/providers/migraphx/gpu_data_transfer.h
@@ -3,19 +3,96 @@

#pragma once

#include <algorithm>
#include <mutex>
#include <vector>

#include "core/providers/migraphx/migraphx_inc.h"
#include "core/framework/data_transfer.h"

namespace onnxruntime {

// Thread-safe pool of hipHostMalloc'd staging buffers used to avoid the
// silent synchronous fallback that hipMemcpyAsync performs when handed
// pageable (non-pinned) host memory. Buffers are grown on demand and
// recycled between copies via hipLaunchHostFunc callbacks.
class PinnedStagingPool {
struct Buffer {
void* ptr;
size_t capacity;
};

public:
PinnedStagingPool() = default;

~PinnedStagingPool() {
(void)hipDeviceSynchronize();
for (auto& b : pool_) {
(void)hipHostFree(b.ptr);
}
}

PinnedStagingPool(const PinnedStagingPool&) = delete;
PinnedStagingPool& operator=(const PinnedStagingPool&) = delete;

// Returns a pinned buffer with at least `bytes` capacity, or nullptr on
// allocation failure. Prefers the smallest adequate buffer already in
// the pool to minimize waste. Callers release with the same `bytes` they
// acquired, so a larger recycled buffer is re-tracked with a conservative
// (smaller) capacity; this is safe, merely suboptimal.
void* Acquire(size_t bytes) {
std::lock_guard<std::mutex> lock(mu_);
TrimLocked();  // free any excess parked by callback-context Release calls
auto best = pool_.end();
for (auto it = pool_.begin(); it != pool_.end(); ++it) {
if (it->capacity >= bytes &&
(best == pool_.end() || it->capacity < best->capacity)) {
best = it;
}
}
if (best != pool_.end()) {
void* p = best->ptr;
pool_.erase(best);
return p;
}
void* p = nullptr;
if (hipHostMalloc(&p, bytes) != hipSuccess) return nullptr;
return p;
}

// Parks a buffer for reuse. May run inside a hipLaunchHostFunc callback,
// where making HIP API calls (including hipHostFree) is not allowed, so
// nothing is freed here; the pool is trimmed on the next Acquire instead.
void Release(void* ptr, size_t capacity) {
std::lock_guard<std::mutex> lock(mu_);
pool_.push_back({ptr, capacity});
}

private:
// Frees the smallest buffers until the pool is back under its cap.
// Requires mu_ to be held; never called from a stream callback.
void TrimLocked() {
while (pool_.size() > kMaxPoolSize) {
auto smallest = std::min_element(
pool_.begin(), pool_.end(),
[](const Buffer& a, const Buffer& b) { return a.capacity < b.capacity; });
(void)hipHostFree(smallest->ptr);
pool_.erase(smallest);
}
}

static constexpr size_t kMaxPoolSize = 8;
std::mutex mu_;
std::vector<Buffer> pool_;
};

class GPUDataTransfer : public IDataTransfer {
public:
GPUDataTransfer() = default;
~GPUDataTransfer() = default;
explicit GPUDataTransfer(hipStream_t stream = nullptr) : stream_(stream) {}
~GPUDataTransfer();

bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override;

private:
static constexpr size_t kStagingThreshold = 64 * 1024; // 64 KiB
hipStream_t stream_;
mutable PinnedStagingPool staging_pool_;
};

} // namespace onnxruntime
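A hypothetical usage sketch of the pool contract (sizes and names illustrative, not from this PR):

    onnxruntime::PinnedStagingPool pool;

    void* a = pool.Acquire(1 << 20);    // hipHostMalloc'd, capacity >= 1 MiB
    // ... std::memcpy into `a`, hipMemcpyAsync from it, synchronize ...
    pool.Release(a, 1 << 20);           // parked for reuse, not freed

    void* b = pool.Acquire(64 * 1024);  // best fit: reuses the parked 1 MiB buffer
    pool.Release(b, 64 * 1024);         // released with the requested size
    // ~PinnedStagingPool() syncs the device, then hipHostFree's parked buffers.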