diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 37bdbd514..16e8858a3 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -430,7 +431,7 @@ struct EndpointConfig { int maxWrPerSend = DefaultMaxWrPerSend, Mode mode = Mode::Default) : deviceIndex(deviceIndex), port(port), - gidIndex(gidIndex), + gidIndex(env()->ibGidIndex >= 0 ? env()->ibGidIndex : gidIndex), maxCqSize(maxCqSize), maxCqPollNum(maxCqPollNum), maxSendWr(maxSendWr), diff --git a/include/mscclpp/env.hpp b/include/mscclpp/env.hpp index 39f73e8d8..696255d23 100644 --- a/include/mscclpp/env.hpp +++ b/include/mscclpp/env.hpp @@ -110,6 +110,11 @@ class Env { /// Default is false. const bool forceDisableNvls; + /// Env name: `MSCCLPP_IB_GID_INDEX`. The GID index to use for IB transport. + /// When set to a non-negative value, overrides the `gidIndex` parameter passed to `EndpointConfig::Ib`. + /// Default is -1 (unset, uses the constructor argument which defaults to `EndpointConfig::Ib::DefaultGidIndex`). + const int ibGidIndex; + private: Env(); diff --git a/python/csrc/env_py.cpp b/python/csrc/env_py.cpp index ce89fd3da..c1d465ae9 100644 --- a/python/csrc/env_py.cpp +++ b/python/csrc/env_py.cpp @@ -23,7 +23,12 @@ void register_env(nb::module_& m) { .def_ro("ibv_mode", &Env::ibvMode) .def_ro("cache_dir", &Env::cacheDir) .def_ro("npkit_dump_dir", &Env::npkitDumpDir) - .def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream); + .def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream) + .def_ro("nccl_shared_lib_path", &Env::ncclSharedLibPath) + .def_ro("force_nccl_fallback_operation", &Env::forceNcclFallbackOperation) + .def_ro("nccl_symmetric_memory", &Env::ncclSymmetricMemory) + .def_ro("force_disable_nvls", &Env::forceDisableNvls) + .def_ro("ib_gid_index", &Env::ibGidIndex); m.def("env", &env); } diff --git a/src/core/env.cpp b/src/core/env.cpp index 484b40af1..4e25bfeaf 100644 --- a/src/core/env.cpp +++ b/src/core/env.cpp @@ -65,7 +65,8 @@ Env::Env() ncclSharedLibPath(readEnv("MSCCLPP_NCCL_LIB_PATH", "")), forceNcclFallbackOperation(readEnv("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", "")), ncclSymmetricMemory(readEnv("MSCCLPP_NCCL_SYMMETRIC_MEMORY", false)), - forceDisableNvls(readEnv("MSCCLPP_FORCE_DISABLE_NVLS", false)) {} + forceDisableNvls(readEnv("MSCCLPP_FORCE_DISABLE_NVLS", false)), + ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", -1)) {} std::shared_ptr env() { static std::shared_ptr globalEnv = std::shared_ptr(new Env()); @@ -93,6 +94,7 @@ std::shared_ptr env() { logEnv("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", globalEnv->forceNcclFallbackOperation); logEnv("MSCCLPP_NCCL_SYMMETRIC_MEMORY", globalEnv->ncclSymmetricMemory); logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls); + logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex); } return globalEnv; } diff --git a/src/core/executor/executor.cc b/src/core/executor/executor.cc index bf2caf97f..9229f9ac8 100644 --- a/src/core/executor/executor.cc +++ b/src/core/executor/executor.cc @@ -109,7 +109,7 @@ namespace mscclpp { struct ExecutionContext { std::shared_ptr proxyService; - std::unordered_map connections; + std::vector connections; std::vector> nvlsConnections; MemoryId localMemoryIdBegin = MemoryId(0); @@ -121,8 +121,6 @@ struct ExecutionContext { // local registered memories to keep resources alive std::vector localRegisteredMemories; - std::vector> memorySemaphores; - std::vector proxySemaphores; std::vector memoryChannels; std::vector portChannels; std::vector nvlsChannels; @@ -266,15 +264,28 @@ struct Executor::Impl { } }; - std::vector connectedPeers = plan.impl_->getConnectedPeers(); - std::vector> connectionFutures; - for (int peer : connectedPeers) { - Transport transport = - !useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode]; - connectionFutures.push_back(this->comm->connect(transport, peer)); + std::unordered_map peerTags; + Transport ibTransport = IBs[rank % this->nranksPerNode]; + std::vector> connFutures; + for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) { + std::vector channelInfos = plan.impl_->getChannelInfos(channelType); + for (const auto& info : channelInfos) { + for (int peer : info.connectedPeers) { + Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc; + connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++)); + } + } + channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType); + for (const auto& info : channelInfos) { + for (int peer : info.connectedPeers) { + Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc; + connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++)); + } + } } - for (size_t i = 0; i < connectionFutures.size(); i++) { - context.connections[connectedPeers[i]] = connectionFutures[i].get(); + + for (auto& future : connFutures) { + context.connections.push_back(future.get()); } std::vector nvlsInfos = plan.impl_->nvlsInfos.at(rank); @@ -328,10 +339,11 @@ struct Executor::Impl { std::vector> futureProxySemaphores; std::vector> memorySemaphores; std::vector proxySemaphores; + int connIdx = 0; auto processChannelInfos = [&](std::vector& channelInfos) { for (ChannelInfo& info : channelInfos) { - for (int peer : info.connectedPeers) { - auto connection = context.connections.at(peer); + for (size_t i = 0; i < info.connectedPeers.size(); i++) { + auto& connection = context.connections[connIdx++]; if (info.channelType == ChannelType::MEMORY) { futureMemorySemaphores.push_back(this->comm->buildSemaphore( connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection))); @@ -360,18 +372,15 @@ struct Executor::Impl { proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get())); } - context.memorySemaphores = std::move(memorySemaphores); - context.proxySemaphores = std::move(proxySemaphores); - for (ChannelType channelType : channelTypes) { std::vector channelInfos = plan.impl_->getChannelInfos(channelType); int index = 0; for (ChannelInfo& info : channelInfos) { for (size_t i = 0; i < info.connectedPeers.size(); i++) { if (channelType == ChannelType::MEMORY) { - context.memoryChannels.emplace_back(context.memorySemaphores[index++]); + context.memoryChannels.emplace_back(memorySemaphores[index++]); } else if (channelType == ChannelType::PORT) { - context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++])); + context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++])); } } }