Skip to content

[fix](kt): Fix barrier call to use cpu_group instead of device_group#27

Open
SCDESPERTATE wants to merge 1 commit intokvcache-ai:mainfrom
SCDESPERTATE:fix/layerwise_prefill_barrier_usage
Open

[fix](kt): Fix barrier call to use cpu_group instead of device_group#27
SCDESPERTATE wants to merge 1 commit intokvcache-ai:mainfrom
SCDESPERTATE:fix/layerwise_prefill_barrier_usage

Conversation

@SCDESPERTATE
Copy link
Copy Markdown

Summary

The comment in the GroupCoordinator.barrier method in parallel_state.py explicitly states: "don't use device_group here." However, the current implementation frequently uses device_group in dist.barrier calls, which introduces unnecessary latency during the expert transfer process in layerwise-prefill. This latency scales with both the number of layers and the number of experts per layer, degrading performance.

Profling

On a system with Quadro RTX 5000x1 and PyTorch 2.7.1, using device_group in dist.barrier adds around 75 microseconds of latency per expert transfer compare with using cpu_group. For a model like Qwen3-30B-A3B (128 experts per layer, 48 layers), this accumulates to nearly 460.8ms of additional overhead.

Modifications

Replace the explicit dist.barrier(group=get_tp_group().device_group) call with the already wrapped GroupCoordinator.barrier method which use cpu_group inside.

@gemini-code-assist
Copy link
Copy Markdown

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a performance bottleneck in the expert transfer process by correcting the synchronization mechanism. By ensuring that barrier calls utilize the CPU group instead of the device group, it eliminates unnecessary latency that previously scaled with model complexity, leading to a substantial improvement in overall efficiency for distributed models.

Highlights

  • Performance Improvement: Replaced dist.barrier calls that incorrectly used device_group with get_tp_group().barrier(), which internally uses cpu_group. This change significantly reduces latency during expert transfer in layerwise-prefill, especially for models with many layers and experts.
  • Refactored Broadcast: Updated the object broadcasting mechanism from dist.broadcast_object_list to the more streamlined get_tp_group().broadcast_object for shared memory unique IDs.
  • Latency Reduction: Profiling showed that using device_group in dist.barrier added approximately 75 microseconds of latency per expert transfer, accumulating to nearly 460.8ms for a Qwen3-30B-A3B model.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/layers/moe/kt_ep_wrapper.py
    • Replaced direct dist.broadcast_object_list with get_tp_group().broadcast_object for shared memory unique ID broadcasting.
    • Changed dist.barrier(group=get_tp_group().device_group) to get_tp_group().barrier() in multiple locations to use the CPU group for synchronization.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces two main changes to the kt_ep_wrapper.py file. First, it refactors a dist.broadcast_object_list call to use the get_tp_group().broadcast_object helper, which improves code readability and abstraction. Second, and more importantly, it replaces multiple calls to dist.barrier using the device_group with get_tp_group().barrier(), which correctly uses the cpu_group. As detailed in the pull request description and confirmed by the implementation in parallel_state.py, this change avoids the high latency associated with NCCL barriers for synchronization, leading to a significant performance improvement. The changes are correct and well-aligned with the goal of reducing latency in the expert transfer process.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant