Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <bitset>
#include <future>
#include <memory>
#include <mscclpp/env.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/version.hpp>
Expand Down Expand Up @@ -430,7 +431,7 @@ struct EndpointConfig {
int maxWrPerSend = DefaultMaxWrPerSend, Mode mode = Mode::Default)
: deviceIndex(deviceIndex),
port(port),
gidIndex(gidIndex),
gidIndex(env()->ibGidIndex > 0 ? env()->ibGidIndex : gidIndex),
maxCqSize(maxCqSize),
maxCqPollNum(maxCqPollNum),
maxSendWr(maxSendWr),
Expand Down
4 changes: 4 additions & 0 deletions include/mscclpp/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class Env {
/// Default is false.
const bool forceDisableNvls;

/// Env name: `MSCCLPP_IB_GID_INDEX`. The GID index to use for IB transport.
/// Default is 0 (`EndpointConfig::Ib::DefaultGidIndex`).
const int ibGidIndex;

private:
Env();

Expand Down
7 changes: 6 additions & 1 deletion python/csrc/env_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ void register_env(nb::module_& m) {
.def_ro("ibv_mode", &Env::ibvMode)
.def_ro("cache_dir", &Env::cacheDir)
.def_ro("npkit_dump_dir", &Env::npkitDumpDir)
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream);
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream)
.def_ro("nccl_shared_lib_path", &Env::ncclSharedLibPath)
.def_ro("force_nccl_fallback_operation", &Env::forceNcclFallbackOperation)
.def_ro("nccl_symmetric_memory", &Env::ncclSymmetricMemory)
.def_ro("force_disable_nvls", &Env::forceDisableNvls)
.def_ro("ib_gid_index", &Env::ibGidIndex);

m.def("env", &env);
}
4 changes: 3 additions & 1 deletion src/core/env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ Env::Env()
ncclSharedLibPath(readEnv<std::string>("MSCCLPP_NCCL_LIB_PATH", "")),
forceNcclFallbackOperation(readEnv<std::string>("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", "")),
ncclSymmetricMemory(readEnv<bool>("MSCCLPP_NCCL_SYMMETRIC_MEMORY", false)),
forceDisableNvls(readEnv<bool>("MSCCLPP_FORCE_DISABLE_NVLS", false)) {}
forceDisableNvls(readEnv<bool>("MSCCLPP_FORCE_DISABLE_NVLS", false)),
ibGidIndex(readEnv<int>("MSCCLPP_IB_GID_INDEX", 0)) {}

std::shared_ptr<Env> env() {
static std::shared_ptr<Env> globalEnv = std::shared_ptr<Env>(new Env());
Expand Down Expand Up @@ -93,6 +94,7 @@ std::shared_ptr<Env> env() {
logEnv("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", globalEnv->forceNcclFallbackOperation);
logEnv("MSCCLPP_NCCL_SYMMETRIC_MEMORY", globalEnv->ncclSymmetricMemory);
logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls);
logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex);
}
return globalEnv;
}
Expand Down
45 changes: 27 additions & 18 deletions src/core/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ namespace mscclpp {

struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, Connection> connections;
std::vector<Connection> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
MemoryId localMemoryIdBegin = MemoryId(0);

Expand All @@ -121,8 +121,6 @@ struct ExecutionContext {
// local registered memories to keep resources alive
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;

std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
std::vector<mscclpp::BasePortChannel> portChannels;
std::vector<mscclpp::SwitchChannel> nvlsChannels;
Expand Down Expand Up @@ -266,15 +264,28 @@ struct Executor::Impl {
}
};

std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
!useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
connectionFutures.push_back(this->comm->connect(transport, peer));
std::unordered_map<int, int> peerTags;
Transport ibTransport = IBs[rank % this->nranksPerNode];
std::vector<std::shared_future<Connection>> connFutures;
for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++));
}
}
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++));
}
}
}
for (size_t i = 0; i < connectionFutures.size(); i++) {
context.connections[connectedPeers[i]] = connectionFutures[i].get();

for (auto& future : connFutures) {
context.connections.push_back(future.get());
}

std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
Expand Down Expand Up @@ -328,10 +339,11 @@ struct Executor::Impl {
std::vector<std::shared_future<Semaphore>> futureProxySemaphores;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
int connIdx = 0;
auto processChannelInfos = [&](std::vector<ChannelInfo>& channelInfos) {
for (ChannelInfo& info : channelInfos) {
for (int peer : info.connectedPeers) {
auto connection = context.connections.at(peer);
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
auto& connection = context.connections[connIdx++];
if (info.channelType == ChannelType::MEMORY) {
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
Expand Down Expand Up @@ -360,18 +372,15 @@ struct Executor::Impl {
proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get()));
}

context.memorySemaphores = std::move(memorySemaphores);
context.proxySemaphores = std::move(proxySemaphores);

for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
if (channelType == ChannelType::MEMORY) {
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
context.memoryChannels.emplace_back(memorySemaphores[index++]);
} else if (channelType == ChannelType::PORT) {
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++]));
}
}
}
Expand Down
Loading