diff --git a/include/mscclpp/proxy_channel_device.hpp b/include/mscclpp/proxy_channel_device.hpp index 2b91dcec4..200b0caec 100644 --- a/include/mscclpp/proxy_channel_device.hpp +++ b/include/mscclpp/proxy_channel_device.hpp @@ -114,6 +114,11 @@ struct ProxyChannelDeviceHandle { /// Push a @ref TriggerFlag to the FIFO. MSCCLPP_DEVICE_INLINE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); } + + MSCCLPP_DEVICE_INLINE void signal(const uint64_t count) { + for (uint64_t i = 0; i < count; ++i) + fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); + } /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param dst The destination memory region. @@ -165,8 +170,9 @@ struct ProxyChannelDeviceHandle { } /// Check if the proxy channel has been signaled. + /// @param max_poll The max number of signals to poll. /// @return true if the proxy channel has been signaled. - MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); } + MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return semaphore_.poll(max_poll); } /// Wait for the proxy channel to be signaled. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. @@ -195,7 +201,7 @@ struct SimpleProxyChannelDeviceHandle { MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t size) { put(offset, offset, size); } /// Push a @ref TriggerFlag to the FIFO. - MSCCLPP_DEVICE_INLINE void signal() { proxyChan_.signal(); } + MSCCLPP_DEVICE_INLINE void signal(const uint64_t count = 1) { proxyChan_.signal(count); } /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param dstOffset The offset into the destination memory region. @@ -229,8 +235,9 @@ struct SimpleProxyChannelDeviceHandle { MSCCLPP_DEVICE_INLINE void flush() { proxyChan_.flush(); } /// Check if the proxy channel has been signaled. + /// @param max_poll The max number of signals to poll. /// @return true if the proxy channel has been signaled. - MSCCLPP_DEVICE_INLINE bool poll() { return proxyChan_.poll(); } + MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return proxyChan_.poll(max_poll); } /// Wait for the proxy channel to be signaled. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp index cd455078a..909abe6f8 100644 --- a/include/mscclpp/semaphore_device.hpp +++ b/include/mscclpp/semaphore_device.hpp @@ -17,11 +17,18 @@ namespace mscclpp { struct Host2DeviceSemaphoreDeviceHandle { #if defined(MSCCLPP_DEVICE_COMPILE) /// Poll if the host has signaled. - /// @return true if the host has signaled. - MSCCLPP_DEVICE_INLINE bool poll() { - bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId)); - if (signaled) (*expectedInboundSemaphoreId) += 1; - return signaled; + /// @param max_poll The max number of signals to poll. + /// @return number of signals up to max_poll that the remote device has signaled. + MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { + if (max_poll <= 0) return 0; + uint64_t count = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) - (*expectedInboundSemaphoreId)); + if (count <= 0) { + return 0; + } else { + if (max_poll < count) count = max_poll; + *expectedInboundSemaphoreId += count; + return count; + } } /// Wait for the host to signal. @@ -40,11 +47,18 @@ struct Host2DeviceSemaphoreDeviceHandle { struct SmDevice2DeviceSemaphoreDeviceHandle { #if defined(MSCCLPP_DEVICE_COMPILE) /// Poll if the remote device has signaled. - /// @return true if the remote device has signaled. - MSCCLPP_DEVICE_INLINE bool poll() { - bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId)); - if (signaled) (*expectedInboundSemaphoreId) += 1; - return signaled; + /// @param max_poll The max number of signals to poll. + /// @return number of signals up to max_poll that the remote device has signaled. + MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { + if (max_poll <= 0) return 0; + uint64_t count = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) - (*expectedInboundSemaphoreId)); + if (count <= 0) { + return 0; + } else { + if (max_poll < count) count = max_poll; + *expectedInboundSemaphoreId += count; + return count; + } } /// Wait for the remote device to signal. @@ -59,10 +73,10 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { /// This function guarantees that all the memory operation before this function is completed before the remote /// semaphore is signaled. /// - MSCCLPP_DEVICE_INLINE void signal() { + MSCCLPP_DEVICE_INLINE void signal(const uint64_t count = 1) { // This fence ensures that preceding writes are visible on the peer GPU before the incremented // `outboundSemaphoreId` is visible. - semaphoreIncrement(); + semaphoreIncrement(count); atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderSeqCst); } @@ -90,7 +104,7 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { } /// Increase the counter of the local semaphore. - MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { *outboundSemaphoreId += 1; } + MSCCLPP_DEVICE_INLINE void semaphoreIncrement(const uint64_t count = 1) { *outboundSemaphoreId += count; } /// Get the value of the local semaphore. MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; } diff --git a/include/mscclpp/sm_channel_device.hpp b/include/mscclpp/sm_channel_device.hpp index e49a431b7..99987fbfb 100644 --- a/include/mscclpp/sm_channel_device.hpp +++ b/include/mscclpp/sm_channel_device.hpp @@ -243,7 +243,7 @@ struct SmChannelDeviceHandle { /// This function guarantees that all the memory operation before this function is completed before the remote /// semaphore is signaled. /// - MSCCLPP_DEVICE_INLINE void signal() { semaphore_.signal(); } + MSCCLPP_DEVICE_INLINE void signal(uint64_t count = 1) { semaphore_.signal(count); } /// Signal the remote semaphore. /// @@ -267,8 +267,9 @@ struct SmChannelDeviceHandle { MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); } /// Check if the remote semaphore has signaled. + /// @param max_poll The max number of signals to poll. /// @return true if the remote semaphore has signaled. - MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); } + MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return semaphore_.poll(max_poll); } /// Wait for the remote semaphore to send a signal. /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.