121 changes: 121 additions & 0 deletions slime_multi_node_rl/Dockerfile.anyscale
@@ -0,0 +1,121 @@
FROM anyscale/ray:2.54.0-py312-cu129

ARG PATCH_VERSION=latest
ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862
ARG SGLANG_COMMIT=dce8b0606c06d3a191a24c7b8cbe8e238ab316c9
ARG SLIME_REF=main

# Anyscale base image runs as non-root; switch to root for system installs.
USER root
WORKDIR /root

RUN apt-get update && \
    apt-get install -y --no-install-recommends git rsync dnsutils && \
    rm -rf /var/lib/apt/lists/*

# Keep pip tooling current and pin numpy to 1.x for Megatron compatibility.
RUN python -m pip install --upgrade pip setuptools wheel && \
python -m pip install "numpy<2" huggingface_hub

# Downgrade PyTorch from 2.9.1 (Anyscale base) to 2.7.1 for compatibility
# with pre-built flash-attn and transformer_engine wheels.
RUN python -m pip install torch==2.7.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128

# Pre-built flash-attn wheel for torch 2.7 + cu12 (source compilation
# exceeds Anyscale's ~60 min build timeout).
RUN python -m pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1%2Bcu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl

# Apex: install Python-only (no CUDA extensions) to stay within Anyscale's
# ~60 min build timeout. Megatron falls back to PyTorch-native kernels.
RUN git clone --filter=blob:none https://github.com/NVIDIA/apex.git /tmp/apex && \
    cd /tmp/apex && \
    git checkout 10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 && \
    python -m pip install --disable-pip-version-check --no-cache-dir \
        --no-build-isolation . && \
    rm -rf /tmp/apex

# Install SGLang from source; the slime patch is applied in a later step.
RUN git clone https://github.com/sgl-project/sglang.git /root/sglang && \
    cd /root/sglang && \
    git checkout ${SGLANG_COMMIT} && \
    python -m pip install -e "python[all]"

# Install Megatron-LM from source; the slime patch is applied in a later step.
RUN git clone --recursive https://github.com/NVIDIA/Megatron-LM.git /root/Megatron-LM && \
    cd /root/Megatron-LM && \
    git checkout ${MEGATRON_COMMIT} && \
    python -m pip install -e .

# Pull slime source for patches and dependency manifests.
RUN git clone https://github.com/THUDM/slime.git /tmp/slime && \
    cd /tmp/slime && \
    git checkout ${SLIME_REF}

RUN cd /root/sglang && \
    cp /tmp/slime/docker/patch/${PATCH_VERSION}/sglang.patch ./sglang.patch && \
    git update-index --refresh && \
    git apply sglang.patch --3way && \
    if grep -R -n '^<<<<<<< ' .; then \
        echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \
        exit 1; \
    fi && \
    rm sglang.patch

RUN cd /root/Megatron-LM && \
    cp /tmp/slime/docker/patch/${PATCH_VERSION}/megatron.patch ./megatron.patch && \
    git update-index --refresh && \
    git apply megatron.patch --3way && \
    if grep -R -n '^<<<<<<< ' .; then \
        echo "Megatron patch failed to apply cleanly. Please resolve conflicts." && \
        exit 1; \
    fi && \
    rm megatron.patch

RUN python -m pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps && \
    python -m pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall && \
    python -m pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation && \
    python -m pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation

RUN python -m pip install -r /tmp/slime/requirements.txt && \
    python -m pip install -e /tmp/slime --no-deps && \
    cd /tmp/slime/slime/backends/megatron_utils/kernels/int4_qat && \
    python -m pip install . --no-build-isolation

# Re-pin PyTorch 2.7.1 and reinstall flash-attn + TE at the end.
# Earlier installs (SGLang, modelopt, etc.) may have upgraded torch,
# breaking pre-built binary wheels. Re-pinning here ensures all
# native extensions (flash-attn, TE) match the same PyTorch ABI.
RUN python -c "import torch; print(f'Before re-pin: PyTorch {torch.__version__}')"
RUN python -m pip install torch==2.7.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128
RUN python -m pip install --force-reinstall --no-deps \
    https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1%2Bcu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
RUN python -m pip install --no-build-isolation "transformer_engine[pytorch]==2.10.0"
# Reinstall sgl_kernel from PyPI (needed for package metadata), then patch
# load_utils.py to return a lazy stub instead of crashing when no compatible
# architecture ops are found. PyPI sgl_kernel only ships SM100 ops and is
# compiled against torch 2.9.1, so it can't load on A10G (SM86) + torch 2.7.1.
# SGLang engines on A10G don't use sgl_kernel FP8 ops; the RolloutManager
# (CPU coordinator) just needs the import to succeed without crashing.
RUN python -m pip install --force-reinstall sgl_kernel
# Comprehensive sgl_kernel compat patch:
# Part 1: Patch load_utils.py to return a lazy stub (handles Python imports)
# Part 2: Create a .pth startup hook that monkeypatches torch.library.register_fake
# to silently skip when sgl_kernel C++ operators are missing (handles torch ops)
RUN echo 'aW1wb3J0IHBhdGhsaWIsIHN5cywgaW1wb3J0bGliLnV0aWwsIHJlCgojID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PQojIFBhcnQgMTogUGF0Y2ggc2dsX2tlcm5lbC9sb2FkX3V0aWxzLnB5IHRvIHJldHVybiBhIHN0dWIgaW5zdGVhZCBvZgojICAgICAgICAgIGNyYXNoaW5nIHdoZW4gbm8gY29tcGF0aWJsZSBHUFUgb3BzIGFyZSBmb3VuZC4KIyA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT0Kc3BlYyA9IGltcG9ydGxpYi51dGlsLmZpbmRfc3BlYygnc2dsX2tlcm5lbCcpCmlmIG5vdCBzcGVjIG9yIG5vdCBzcGVjLnN1Ym1vZHVsZV9zZWFyY2hfbG9jYXRpb25zOgogICAgcHJpbnQoInNnbF9rZXJuZWwgcGFja2FnZSBub3QgZm91bmQsIHNraXBwaW5nIHBhdGNoIikKICAgIHN5cy5leGl0KDApCgpwa2dfZGlyID0gcGF0aGxpYi5QYXRoKHNwZWMuc3VibW9kdWxlX3NlYXJjaF9sb2NhdGlvbnNbMF0pCnAgPSBwa2dfZGlyIC8gJ2xvYWRfdXRpbHMucHknCmlmIG5vdCBwLmV4aXN0cygpOgogICAgcHJpbnQoZiJsb2FkX3V0aWxzLnB5IG5vdCBmb3VuZCBhdCB7cH0sIHNraXBwaW5nIHBhdGNoIikKICAgIHN5cy5leGl0KDApCgpzcmMgPSBwLnJlYWRfdGV4dCgpCgptYXRjaCA9IHJlLnNlYXJjaChyJ14oICopcmFpc2UgSW1wb3J0RXJyb3JcKGVycm9yX21zZ1wpJywgc3JjLCByZS5NVUxUSUxJTkUpCmlmIG5vdCBtYXRjaDoKICAgIHByaW50KGYiV0FSTklORzogJ3JhaXNlIEltcG9ydEVycm9yKGVycm9yX21zZyknIG5vdCBmb3VuZCBpbiB7cH0iKQogICAgc3lzLmV4aXQoMSkKCmluZGVudCA9IG1hdGNoLmdyb3VwKDEpCnRhcmdldCA9ICdyYWlzZSBJbXBvcnRFcnJvcihlcnJvcl9tc2cpJwoKcmVwbCA9ICdcbicuam9pbihbCiAgICAnaW1wb3J0IHdhcm5pbmdzJywKICAgIGYne2luZGVudH13YXJuaW5ncy53YXJuKCJzZ2xfa2VybmVsOiBubyBjb21wYXRpYmxlIG9wcyBmb3IgdGhpcyBHUFU7ICInLAogICAgZid7aW5kZW50fSAgICAgICAgICAgICAgIkNVREEga2VybmVscyB3aWxsIGJlIHVuYXZhaWxhYmxlLiIsIFJ1bnRpbWVXYXJuaW5nKScsCiAgICBmJ3tpbmRlbnR9Y2xhc3MgX0xhenlTdHViOicsCiAgICBmJ3tpbmRlbnR9ICAgIGRlZiBfX2dldGF0dHJfXyhzZWxmLCBuYW1lKTonLAogICAgZid7aW5kZW50fSAgICAgICAgZGVmIF91bmF2YWlsYWJsZSgqYSwgKiprdyk6JywKICAgIGYne2luZGVudH0gICAgICAgICAgICByYWlzZSBSdW50aW1lRXJyb3IoZiJzZ2xfa2VybmVsLnt7bmFtZX19IHVuYXZhaWxhYmxlIG9uIHRoaXMgR1BVIGFyY2giKScsCiAgICBmJ3tpbmRlbnR9ICAgICAgICByZXR1cm4gX3VuYXZhaWxhYmxlJywKICAgIGYne2luZGVudH1yZXR1cm4gX0xhenlTdHViKCknLApdKQoKc3JjID0gc3JjLnJlcGxhY2UodGFyZ2V0LCByZXBsKQpwLndyaXRlX3RleHQoc3JjKQpwcmludChmIlBhdGNoZWQge3B9IChpbmRlbnQ9e2xlbihpbmRlbnQpfSBzcGFjZXMpIikKCiMgQ2xlYXIgYnl0ZWNvZGUgY2FjaGUgZm9yIGxvYWRfdXRpbHMKcHljYWNoZSA9IHBrZ19kaXIgLyAnX19weWNhY2hlX18nCmlmIHB5Y2FjaGUuZXhpc3RzKCk6CiAgICBmb3IgcHljIGluIHB5Y2FjaGUuZ2xvYignbG9hZF91dGlscyonKToKICAgICAgICBweWMudW5saW5rKCkKCiMgPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09CiMgUGFydCAyOiBDcmVhdGUgYSBzaXRlLXBhY2thZ2VzIG1vZHVsZSB0aGF0IG1vbmtleXBhdGNoZXMKIyAgICAgICAgICB0b3JjaC5saWJyYXJ5LnJlZ2lzdGVyX2Zha2UgdG8gc2lsZW50bHkgc2tpcCB3aGVuIHRoZQojICAgICAgICAgIHVuZGVybHlpbmcgQysrIG9wZXJhdG9yIGRvZXNuJ3QgZXhpc3QuCiMgPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09CnNpdGVfZGlyID0gcGtnX2Rpci5wYXJlbnQgICMgc2l0ZS1wYWNrYWdlcyBkaXJlY3RvcnkKCmNvbXBhdF9tb2R1bGUgPSBzaXRlX2RpciAvICdfc2dsX2tlcm5lbF9jb21wYXQucHknCmNvbXBhdF9tb2R1bGUud3JpdGVfdGV4dCgnJydcCiIiIk1vbmtleXBhdGNoIHRvcmNoLmxpYnJhcnkucmVnaXN0ZXJfZmFrZSB0byBoYW5kbGUgbWlzc2luZyBzZ2xfa2VybmVsIG9wcy4iIiIKaW1wb3J0IHRvcmNoCgpfb3JpZ19yZWdpc3Rlcl9mYWtlID0gdG9yY2gubGlicmFyeS5yZWdpc3Rlcl9mYWtlCgpkZWYgX3Jlc2lsaWVudF9yZWdpc3Rlcl9mYWtlKG9wX25hbWUsIGZuPU5vbmUsIC8sICoqa3dhcmdzKToKICAgICIiIldyYXBwZXIgdGhhdCByZXR1cm5zIGEgbm8tb3AgZGVjb3JhdG9yIGlmIHRoZSBvcGVyYXRvciBkb2Vzbid0IGV4aXN0LiIiIgogICAgdHJ5OgogICAgICAgIHJldHVybiBfb3JpZ19yZWdpc3Rlcl9mYWtlKG9wX25hbWUsIGZuLCAqKmt3YXJncykKICAgIGV4Y2VwdCBSdW50aW1lRXJyb3IgYXMgZToKICAgICAgICBpZiAiZG9lcyBub3QgZXhpc3QiIGluIHN0cihlKSBhbmQgInNnbF9rZXJuZWwiIGluIHN0cihvcF9uYW1lKToKICAgICAgICAgICAgIyBzZ2xfa2VybmVsIG9wcyBub3Qgb
G9hZGVkIChhcmNoaXRlY3R1cmUgbWlzbWF0Y2gpOyBza2lwIHJlZ2lzdHJhdGlvbgogICAgICAgICAgICBpZiBmbiBpcyBub3QgTm9uZToKICAgICAgICAgICAgICAgIHJldHVybiBmbgogICAgICAgICAgICByZXR1cm4gbGFtYmRhIGY6IGYKICAgICAgICByYWlzZQoKdG9yY2gubGlicmFyeS5yZWdpc3Rlcl9mYWtlID0gX3Jlc2lsaWVudF9yZWdpc3Rlcl9mYWtlCicnJykKcHJpbnQoZiJDcmVhdGVkIHtjb21wYXRfbW9kdWxlfSIpCgojIENyZWF0ZSAucHRoIGZpbGUgc28gdGhlIGNvbXBhdCBtb2R1bGUgcnVucyBvbiBldmVyeSBQeXRob24gc3RhcnR1cApwdGhfZmlsZSA9IHNpdGVfZGlyIC8gJ19zZ2xfa2VybmVsX2NvbXBhdC5wdGgnCnB0aF9maWxlLndyaXRlX3RleHQoJ2ltcG9ydCBfc2dsX2tlcm5lbF9jb21wYXRcbicpCnByaW50KGYiQ3JlYXRlZCB7cHRoX2ZpbGV9IikK' | base64 -d | python
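
# For reviewers, the encoded script above boils down to two patterns (decoded
# sketch; the base64 blob itself is authoritative):
#   Part 1 -- in sgl_kernel/load_utils.py, replace `raise ImportError(error_msg)`
#   with a warning plus a stub whose attributes fail only when actually called:
#       class _LazyStub:
#           def __getattr__(self, name):
#               def _unavailable(*a, **kw):
#                   raise RuntimeError(f"sgl_kernel.{name} unavailable on this GPU arch")
#               return _unavailable
#       return _LazyStub()
#   Part 2 -- write _sgl_kernel_compat.py plus a .pth hook into site-packages so
#   every interpreter start wraps torch.library.register_fake in a try/except
#   that skips registration when the sgl_kernel C++ op was never loaded.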

# Patch Triton's driver.py so that importing transformer_engine does not crash
# on CPU-only nodes/processes (RolloutManager, head-node driver, etc.).
# Triton 3.3.1 (bundled with torch 2.7.1) raises RuntimeError when no GPU
# driver is found; we replace the error with a graceful no-op fallback.
RUN echo 'aW1wb3J0IHBhdGhsaWIsIHN5cywgaW1wb3J0bGliLnV0aWwKCiMgRmluZCB0cml0b24gZHJpdmVyLnB5IFdJVEhPVVQgaW1wb3J0aW5nIGl0IChpbXBvcnQgdHJpZ2dlcnMgdGhlIGNyYXNoKQpzcGVjID0gaW1wb3J0bGliLnV0aWwuZmluZF9zcGVjKCd0cml0b24ucnVudGltZS5kcml2ZXInKQppZiBub3Qgc3BlYyBvciBub3Qgc3BlYy5vcmlnaW46CiAgICBwcmludCgidHJpdG9uLnJ1bnRpbWUuZHJpdmVyIG5vdCBmb3VuZCwgc2tpcHBpbmcgcGF0Y2giKQogICAgc3lzLmV4aXQoMCkKCnAgPSBwYXRobGliLlBhdGgoc3BlYy5vcmlnaW4pCnNyYyA9IHAucmVhZF90ZXh0KCkKdGFyZ2V0ID0gJ3JhaXNlIFJ1bnRpbWVFcnJvcihmIntsZW4oYWN0aXZlcyl9IGFjdGl2ZSBkcml2ZXJzICh7YWN0aXZlc30pLiBUaGVyZSBzaG91bGQgb25seSBiZSBvbmUuIiknCmlmIHRhcmdldCBub3QgaW4gc3JjOgogICAgcHJpbnQoZiJXQVJOSU5HOiBwYXRjaCB0YXJnZXQgbm90IGZvdW5kIGluIHtwfSIpCiAgICBwcmludCgiRmlsZSBjb250ZW50czoiKQogICAgcHJpbnQoc3JjKQogICAgc3lzLmV4aXQoMSkKCnJlcGwgPSAoCiAgICAnaWYgbGVuKGFjdGl2ZXMpID09IDA6XG4nCiAgICAnICAgICAgICAgICAgY2xhc3MgX051bGxEcml2ZXI6XG4nCiAgICAnICAgICAgICAgICAgICAgIGRlZiBnZXRfYmVuY2htYXJrZXIoc2VsZik6IHJldHVybiBsYW1iZGEgKmEsICoqa3c6IDAuMFxuJwogICAgJyAgICAgICAgICAgICAgICBkZWYgZ2V0X2N1cnJlbnRfdGFyZ2V0KHNlbGYpOiByZXR1cm4gTm9uZVxuJwogICAgJyAgICAgICAgICAgICAgICBkZWYgaXNfYWN0aXZlKHNlbGYpOiByZXR1cm4gRmFsc2VcbicKICAgICcgICAgICAgICAgICByZXR1cm4gX051bGxEcml2ZXIoKVxuJwogICAgJyAgICAgICAgJyArIHRhcmdldAopCnAud3JpdGVfdGV4dChzcmMucmVwbGFjZSh0YXJnZXQsIHJlcGwpKQoKIyBDbGVhciBieXRlY29kZSBjYWNoZQpweWNhY2hlID0gcC5wYXJlbnQgLyAnX19weWNhY2hlX18nCmlmIHB5Y2FjaGUuZXhpc3RzKCk6CiAgICBmb3IgcHljIGluIHB5Y2FjaGUuZ2xvYignZHJpdmVyKicpOgogICAgICAgIHB5Yy51bmxpbmsoKQogICAgICAgIHByaW50KGYiUmVtb3ZlZCBjYWNoZWQge3B5Y30iKQoKcHJpbnQoZiJQYXRjaGVkIHtwfSIpCg==' | base64 -d | python
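
# Decoded sketch of the patch above: when Triton finds zero active GPU drivers,
# return a no-op driver instead of raising (more than one driver still raises):
#     if len(actives) == 0:
#         class _NullDriver:
#             def get_benchmarker(self): return lambda *a, **kw: 0.0
#             def get_current_target(self): return None
#             def is_active(self): return False
#         return _NullDriver()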

# Verify torch + flash-attn ABI compatibility.
# (TE verification runs at runtime on GPU nodes where Triton patch takes effect.)
RUN python -c "\
import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}'); \
assert torch.__version__.startswith('2.7'), f'Expected 2.7.x, got {torch.__version__}'; \
from flash_attn import flash_attn_func; print('flash-attn OK')"

WORKDIR /tmp/slime
93 changes: 93 additions & 0 deletions slime_multi_node_rl/README.md
@@ -0,0 +1,93 @@
# Multi-Node RL Training with Slime

[Slime](https://github.com/THUDM/slime) is an RL training framework that uses Megatron-LM for distributed training with disaggregated rollout via SGLang. This example runs GRPO training of Qwen3-1.7B on Anyscale across **2 workers x 4x A10G** (8 GPUs total, 24 GB VRAM each).

## Cluster Layout

```
Head node (m5.2xlarge): CPU-only Ray head; the training driver itself runs on a worker GPU
Worker 0 (4 GPUs): [GPU 0-1: Training TP=2, Stage 0] [GPU 2-3: Rollout*]
Worker 1 (4 GPUs): [GPU 0-1: Training TP=2, Stage 1] [GPU 2-3: Rollout*]
*Three of these four GPUs run SGLang engines; one is reserved for the training driver.
```

- **Training**: 4 GPUs — TP=2 x PP=2 x DP=1 (Megatron backend; pipeline stages span the two workers)
- **Rollout**: 3 GPUs — disaggregated SGLang inference, 1 GPU per engine (the fourth rollout-side GPU is reserved for the training driver; see the launch sketch below)
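
In flag form, this layout corresponds roughly to the following. This is an illustrative sketch only: the flag names follow slime's conventions but may drift between versions, and the authoritative values live in `anyscale-smoke-2node-a10g.sh`.

```bash
# Sketch only; flag names follow slime's conventions and may drift across versions.
python train.py \
  --actor-num-nodes 2 --actor-num-gpus-per-node 2 \
  --tensor-model-parallel-size 2 --pipeline-model-parallel-size 2 \
  --rollout-num-gpus 3 --rollout-num-gpus-per-engine 1
```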

## Files

| File | Description |
|------|-------------|
| `job.yaml` | Anyscale job config (`m5.2xlarge` head + 2x `g5.12xlarge` workers) |
| `Dockerfile.anyscale` | Docker image with Slime, Megatron-LM, SGLang, and A10G compatibility patches |
| `anyscale-smoke-2node-a10g.sh` | Anyscale entrypoint (downloads model/data, converts weights, runs training) |
| `patch_all_nodes.py` | Runtime patches for sgl_kernel compatibility on A10G (SM86) |
| `run-qwen3-4B-smoke-2node-a10g.sh` | Bare-metal variant for Qwen3-4B (manual Ray cluster setup) |

## Install the Anyscale CLI

```bash
pip install -U anyscale
anyscale login
```

## Quick Start

Clone the example from GitHub.

```bash
git clone https://github.com/anyscale/examples.git
cd examples/slime_multi_node_rl
```

Submit the job.

```bash
anyscale job submit -f job.yaml
```

The entrypoint automatically:
1. Downloads `Qwen/Qwen3-1.7B` and `zhuzilin/dapo-math-17k` to `/mnt/cluster_storage` (sketched after this list)
2. Converts HF weights to Megatron torch_dist format (on a GPU worker)
3. Patches all nodes for A10G compatibility (sgl_kernel SM86 workaround)
4. Runs GRPO training with `deepscaler` reward model
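
For reference, the first two steps boil down to commands along these lines. This is a hedged sketch: the entrypoint script is authoritative, and the target directories shown here are assumptions.

```bash
# Hedged sketch of steps 1-2; exact paths and flags live in anyscale-smoke-2node-a10g.sh.
huggingface-cli download Qwen/Qwen3-1.7B \
  --local-dir /mnt/cluster_storage/models/Qwen3-1.7B
huggingface-cli download zhuzilin/dapo-math-17k --repo-type dataset \
  --local-dir /mnt/cluster_storage/datasets/dapo-math-17k
```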

## Understanding the Example

- The [Dockerfile.anyscale](Dockerfile.anyscale) builds on `anyscale/ray:2.54.0-py312-cu129` and installs Slime with all of its dependencies. It downgrades PyTorch from 2.9.1 to 2.7.1 so that the pre-built flash-attn and transformer_engine wheels load, and it patches sgl_kernel, whose PyPI wheels ship only SM100 ops and cannot load on A10G (SM86).
- The entrypoint uses `ray job submit --entrypoint-num-gpus 1` to schedule weight conversion and the training driver on GPU worker nodes, since the head node is CPU-only and cannot run Triton/CUDA code (see the sketch after this list).
- Training uses Megatron-LM with TP=2, PP=2 across the two workers. Pipeline parallelism spans nodes via NCCL.
- Rollout uses disaggregated SGLang inference engines (1 GPU each) for generating training samples.
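
A minimal sketch of that scheduling pattern follows. The converter script name is a placeholder for illustration, not the repo's actual path.

```bash
# Request 1 GPU for the entrypoint so Ray places it on a GPU worker rather than
# the CPU-only head node. convert_weights.py is a placeholder name.
ray job submit --entrypoint-num-gpus 1 -- \
  python convert_weights.py --hf-dir /mnt/cluster_storage/models/Qwen3-1.7B
```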

## A10G-Specific Settings

| Setting | Value | Reason |
|---------|-------|--------|
| `NCCL_NVLS_ENABLE` | `0` | No NVLink on cloud A10G |
| `--attention-backend` | `flash` | FA2 only (Ampere, no FA3) |
| `--sglang-attention-backend` | `flashinfer` | For SGLang on Ampere |
| `--max-tokens-per-gpu` | `4096` | Conservative for 24 GB VRAM |
| No FP8 | — | Ampere does not support FP8 |
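
In script form, these settings mirror the table above (the run scripts already apply them; shown here for quick reference):

```bash
# No NVLink/NVSwitch on cloud A10G instances, so disable NVLS collectives.
export NCCL_NVLS_ENABLE=0
# And on the training command line (FP8 flags are simply omitted on Ampere):
#   --attention-backend flash \
#   --sglang-attention-backend flashinfer \
#   --max-tokens-per-gpu 4096
```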

## Verification

A successful run shows the following (the log-tailing command after this list is a convenient way to watch for these markers):
- SGLang engine startup on rollout GPUs
- Cross-node NCCL init for pipeline parallelism
- Training loss values printed each step
- Weight sync between training and rollout engines
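
To watch for these markers from your workstation, tail the job logs. Flag spellings vary across Anyscale CLI versions, so treat `--name`/`--follow` as assumptions and check `anyscale job logs --help`:

```bash
# --name/--follow are assumed flags; verify with `anyscale job logs --help`.
anyscale job logs --name <job-name> --follow
```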

## If You Hit OOM

**Training GPUs:**
1. `--max-tokens-per-gpu` -> `2048`
2. `--rollout-max-response-len` -> `1024`
3. `--n-samples-per-prompt` -> `2` and `--global-batch-size` -> `16`
4. Add `--optimizer-cpu-offload`

**Rollout GPUs:**
1. `--sglang-mem-fraction-static` -> `0.5`
2. Add `--sglang-chunked-prefill-size 2048`

## View the Job

View the job in the [jobs tab](https://console.anyscale.com/jobs) of the Anyscale console.