diff --git a/slime_multi_node_rl/Dockerfile.anyscale b/slime_multi_node_rl/Dockerfile.anyscale new file mode 100644 index 0000000..74e55ce --- /dev/null +++ b/slime_multi_node_rl/Dockerfile.anyscale @@ -0,0 +1,121 @@ +FROM anyscale/ray:2.54.0-py312-cu129 + +ARG PATCH_VERSION=latest +ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862 +ARG SGLANG_COMMIT=dce8b0606c06d3a191a24c7b8cbe8e238ab316c9 +ARG SLIME_REF=main + +# Anyscale base image runs as non-root; switch to root for system installs. +USER root +WORKDIR /root + +RUN apt-get update && \ + apt-get install -y --no-install-recommends git rsync dnsutils && \ + rm -rf /var/lib/apt/lists/* + +# Keep pip tooling current and pin numpy to 1.x for Megatron compatibility. +RUN python -m pip install --upgrade pip setuptools wheel && \ + python -m pip install "numpy<2" huggingface_hub + +# Downgrade PyTorch from 2.9.1 (Anyscale base) to 2.7.1 for compatibility +# with pre-built flash-attn and transformer_engine wheels. +RUN python -m pip install torch==2.7.1 torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/cu128 + +# Pre-built flash-attn wheel for torch 2.7 + cu12 (source compilation +# exceeds Anyscale's ~60 min build timeout). +RUN python -m pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1%2Bcu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl + +# Apex: install Python-only (no CUDA extensions) to stay within Anyscale's +# ~60 min build timeout. Megatron falls back to PyTorch-native kernels. +RUN git clone --filter=blob:none https://github.com/NVIDIA/apex.git /tmp/apex && \ + cd /tmp/apex && \ + git checkout 10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 && \ + python -m pip install --disable-pip-version-check --no-cache-dir \ + --no-build-isolation . && \ + rm -rf /tmp/apex + +# Install SGLang from source and apply slime patch. 
+RUN git clone https://github.com/sgl-project/sglang.git /root/sglang && \ + cd /root/sglang && \ + git checkout ${SGLANG_COMMIT} && \ + python -m pip install -e "python[all]" + +# Install Megatron from source and apply slime patch. +RUN git clone --recursive https://github.com/NVIDIA/Megatron-LM.git /root/Megatron-LM && \ + cd /root/Megatron-LM && \ + git checkout ${MEGATRON_COMMIT} && \ + python -m pip install -e . + +# Pull slime source for patches and dependency manifests. +RUN git clone https://github.com/THUDM/slime.git /tmp/slime && \ + cd /tmp/slime && \ + git checkout ${SLIME_REF} + +RUN cd /root/sglang && \ + cp /tmp/slime/docker/patch/${PATCH_VERSION}/sglang.patch ./sglang.patch && \ + git update-index --refresh && \ + git apply sglang.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm sglang.patch + +RUN cd /root/Megatron-LM && \ + cp /tmp/slime/docker/patch/${PATCH_VERSION}/megatron.patch ./megatron.patch && \ + git update-index --refresh && \ + git apply megatron.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Megatron patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm megatron.patch + +RUN python -m pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps && \ + python -m pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall && \ + python -m pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation && \ + python -m pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +RUN python -m pip install -r /tmp/slime/requirements.txt && \ + python -m pip install -e /tmp/slime --no-deps && \ + cd /tmp/slime/slime/backends/megatron_utils/kernels/int4_qat && \ + python -m pip install . 
--no-build-isolation + +# Re-pin PyTorch 2.7.1 and reinstall flash-attn + TE at the end. +# Earlier installs (SGLang, modelopt, etc.) may have upgraded torch, +# breaking pre-built binary wheels. Re-pinning here ensures all +# native extensions (flash-attn, TE) match the same PyTorch ABI. +RUN python -c "import torch; print(f'Before re-pin: PyTorch {torch.__version__}')" +RUN python -m pip install torch==2.7.1 torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/cu128 +RUN python -m pip install --force-reinstall --no-deps \ + https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1%2Bcu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl +RUN python -m pip install --no-build-isolation "transformer_engine[pytorch]==2.10.0" +# Reinstall sgl_kernel from PyPI (needed for package metadata), then patch +# load_utils.py to return a lazy stub instead of crashing when no compatible +# architecture ops are found. PyPI sgl_kernel only ships SM100 ops and is +# compiled against torch 2.9.1, so it can't load on A10G (SM86) + torch 2.7.1. +# SGLang engines on A10G don't use sgl_kernel FP8 ops; the RolloutManager +# (CPU coordinator) just needs the import to succeed without crashing. 
+RUN python -m pip install --force-reinstall sgl_kernel +# Comprehensive sgl_kernel compat patch: +# Part 1: Patch load_utils.py to return a lazy stub (handles Python imports) +# Part 2: Create a .pth startup hook that monkeypatches torch.library.register_fake +# to silently skip when sgl_kernel C++ operators are missing (handles torch ops) +RUN echo 'aW1wb3J0IHBhdGhsaWIsIHN5cywgaW1wb3J0bGliLnV0aWwsIHJlCgojID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PQojIFBhcnQgMTogUGF0Y2ggc2dsX2tlcm5lbC9sb2FkX3V0aWxzLnB5IHRvIHJldHVybiBhIHN0dWIgaW5zdGVhZCBvZgojICAgICAgICAgIGNyYXNoaW5nIHdoZW4gbm8gY29tcGF0aWJsZSBHUFUgb3BzIGFyZSBmb3VuZC4KIyA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT0Kc3BlYyA9IGltcG9ydGxpYi51dGlsLmZpbmRfc3BlYygnc2dsX2tlcm5lbCcpCmlmIG5vdCBzcGVjIG9yIG5vdCBzcGVjLnN1Ym1vZHVsZV9zZWFyY2hfbG9jYXRpb25zOgogICAgcHJpbnQoInNnbF9rZXJuZWwgcGFja2FnZSBub3QgZm91bmQsIHNraXBwaW5nIHBhdGNoIikKICAgIHN5cy5leGl0KDApCgpwa2dfZGlyID0gcGF0aGxpYi5QYXRoKHNwZWMuc3VibW9kdWxlX3NlYXJjaF9sb2NhdGlvbnNbMF0pCnAgPSBwa2dfZGlyIC8gJ2xvYWRfdXRpbHMucHknCmlmIG5vdCBwLmV4aXN0cygpOgogICAgcHJpbnQoZiJsb2FkX3V0aWxzLnB5IG5vdCBmb3VuZCBhdCB7cH0sIHNraXBwaW5nIHBhdGNoIikKICAgIHN5cy5leGl0KDApCgpzcmMgPSBwLnJlYWRfdGV4dCgpCgptYXRjaCA9IHJlLnNlYXJjaChyJ14oICopcmFpc2UgSW1wb3J0RXJyb3JcKGVycm9yX21zZ1wpJywgc3JjLCByZS5NVUxUSUxJTkUpCmlmIG5vdCBtYXRjaDoKICAgIHByaW50KGYiV0FSTklORzogJ3JhaXNlIEltcG9ydEVycm9yKGVycm9yX21zZyknIG5vdCBmb3VuZCBpbiB7cH0iKQogICAgc3lzLmV4aXQoMSkKCmluZGVudCA9IG1hdGNoLmdyb3VwKDEpCnRhcmdldCA9ICdyYWlzZSBJbXBvcnRFcnJvcihlcnJvcl9tc2cpJwoKcmVwbCA9ICdcbicuam9pbihbCiAgICAnaW1wb3J0IHdhcm5pbmdzJywKICAgIGYne2luZGVudH13YXJuaW5ncy53YXJuKCJzZ2xfa2VybmVsOiBubyBjb21wYXRpYmxlIG9wcyBmb3IgdGhpcyBHUFU7ICInLAogICAgZid7aW5kZW50fSAgICAgICAgICAgICAgIkNVREEga2VybmVscyB3aWxsIGJlIHVuYXZhaWxhYmxlLiIsIFJ1bnRpbWVXYXJuaW5nKScsCiAgICBmJ3tpbmRlbnR9Y2xhc3MgX0xhenlTdHViOicsCiAgICBmJ3tpbmRlbnR9ICAgIGRlZiBfX2dldGF0dHJfXyhzZWxmLCBuYW1lKTonLAogICAgZid7aW5kZW50fS
AgICAgICAgZGVmIF91bmF2YWlsYWJsZSgqYSwgKiprdyk6JywKICAgIGYne2luZGVudH0gICAgICAgICAgICByYWlzZSBSdW50aW1lRXJyb3IoZiJzZ2xfa2VybmVsLnt7bmFtZX19IHVuYXZhaWxhYmxlIG9uIHRoaXMgR1BVIGFyY2giKScsCiAgICBmJ3tpbmRlbnR9ICAgICAgICByZXR1cm4gX3VuYXZhaWxhYmxlJywKICAgIGYne2luZGVudH1yZXR1cm4gX0xhenlTdHViKCknLApdKQoKc3JjID0gc3JjLnJlcGxhY2UodGFyZ2V0LCByZXBsKQpwLndyaXRlX3RleHQoc3JjKQpwcmludChmIlBhdGNoZWQge3B9IChpbmRlbnQ9e2xlbihpbmRlbnQpfSBzcGFjZXMpIikKCiMgQ2xlYXIgYnl0ZWNvZGUgY2FjaGUgZm9yIGxvYWRfdXRpbHMKcHljYWNoZSA9IHBrZ19kaXIgLyAnX19weWNhY2hlX18nCmlmIHB5Y2FjaGUuZXhpc3RzKCk6CiAgICBmb3IgcHljIGluIHB5Y2FjaGUuZ2xvYignbG9hZF91dGlscyonKToKICAgICAgICBweWMudW5saW5rKCkKCiMgPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09CiMgUGFydCAyOiBDcmVhdGUgYSBzaXRlLXBhY2thZ2VzIG1vZHVsZSB0aGF0IG1vbmtleXBhdGNoZXMKIyAgICAgICAgICB0b3JjaC5saWJyYXJ5LnJlZ2lzdGVyX2Zha2UgdG8gc2lsZW50bHkgc2tpcCB3aGVuIHRoZQojICAgICAgICAgIHVuZGVybHlpbmcgQysrIG9wZXJhdG9yIGRvZXNuJ3QgZXhpc3QuCiMgPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09CnNpdGVfZGlyID0gcGtnX2Rpci5wYXJlbnQgICMgc2l0ZS1wYWNrYWdlcyBkaXJlY3RvcnkKCmNvbXBhdF9tb2R1bGUgPSBzaXRlX2RpciAvICdfc2dsX2tlcm5lbF9jb21wYXQucHknCmNvbXBhdF9tb2R1bGUud3JpdGVfdGV4dCgnJydcCiIiIk1vbmtleXBhdGNoIHRvcmNoLmxpYnJhcnkucmVnaXN0ZXJfZmFrZSB0byBoYW5kbGUgbWlzc2luZyBzZ2xfa2VybmVsIG9wcy4iIiIKaW1wb3J0IHRvcmNoCgpfb3JpZ19yZWdpc3Rlcl9mYWtlID0gdG9yY2gubGlicmFyeS5yZWdpc3Rlcl9mYWtlCgpkZWYgX3Jlc2lsaWVudF9yZWdpc3Rlcl9mYWtlKG9wX25hbWUsIGZuPU5vbmUsIC8sICoqa3dhcmdzKToKICAgICIiIldyYXBwZXIgdGhhdCByZXR1cm5zIGEgbm8tb3AgZGVjb3JhdG9yIGlmIHRoZSBvcGVyYXRvciBkb2Vzbid0IGV4aXN0LiIiIgogICAgdHJ5OgogICAgICAgIHJldHVybiBfb3JpZ19yZWdpc3Rlcl9mYWtlKG9wX25hbWUsIGZuLCAqKmt3YXJncykKICAgIGV4Y2VwdCBSdW50aW1lRXJyb3IgYXMgZToKICAgICAgICBpZiAiZG9lcyBub3QgZXhpc3QiIGluIHN0cihlKSBhbmQgInNnbF9rZXJuZWwiIGluIHN0cihvcF9uYW1lKToKICAgICAgICAgICAgIyBzZ2xfa2VybmVsIG9wcyBub3QgbG9hZGVkIChhcmNoaXRlY3R1cmUgbWlzbWF0Y2gpOyBza2lwIHJlZ2lzdHJhdGlvbgogICAgICAgICAgICBpZiBmbiBpcyBub3QgTm
9uZToKICAgICAgICAgICAgICAgIHJldHVybiBmbgogICAgICAgICAgICByZXR1cm4gbGFtYmRhIGY6IGYKICAgICAgICByYWlzZQoKdG9yY2gubGlicmFyeS5yZWdpc3Rlcl9mYWtlID0gX3Jlc2lsaWVudF9yZWdpc3Rlcl9mYWtlCicnJykKcHJpbnQoZiJDcmVhdGVkIHtjb21wYXRfbW9kdWxlfSIpCgojIENyZWF0ZSAucHRoIGZpbGUgc28gdGhlIGNvbXBhdCBtb2R1bGUgcnVucyBvbiBldmVyeSBQeXRob24gc3RhcnR1cApwdGhfZmlsZSA9IHNpdGVfZGlyIC8gJ19zZ2xfa2VybmVsX2NvbXBhdC5wdGgnCnB0aF9maWxlLndyaXRlX3RleHQoJ2ltcG9ydCBfc2dsX2tlcm5lbF9jb21wYXRcbicpCnByaW50KGYiQ3JlYXRlZCB7cHRoX2ZpbGV9IikK' | base64 -d | python + +# Patch Triton's driver.py so that importing transformer_engine does not crash +# on CPU-only nodes/processes (RolloutManager, head-node driver, etc.). +# Triton 3.3.1 (bundled with torch 2.7.1) raises RuntimeError when no GPU +# driver is found; we replace the error with a graceful no-op fallback. +RUN echo 'aW1wb3J0IHBhdGhsaWIsIHN5cywgaW1wb3J0bGliLnV0aWwKCiMgRmluZCB0cml0b24gZHJpdmVyLnB5IFdJVEhPVVQgaW1wb3J0aW5nIGl0IChpbXBvcnQgdHJpZ2dlcnMgdGhlIGNyYXNoKQpzcGVjID0gaW1wb3J0bGliLnV0aWwuZmluZF9zcGVjKCd0cml0b24ucnVudGltZS5kcml2ZXInKQppZiBub3Qgc3BlYyBvciBub3Qgc3BlYy5vcmlnaW46CiAgICBwcmludCgidHJpdG9uLnJ1bnRpbWUuZHJpdmVyIG5vdCBmb3VuZCwgc2tpcHBpbmcgcGF0Y2giKQogICAgc3lzLmV4aXQoMCkKCnAgPSBwYXRobGliLlBhdGgoc3BlYy5vcmlnaW4pCnNyYyA9IHAucmVhZF90ZXh0KCkKdGFyZ2V0ID0gJ3JhaXNlIFJ1bnRpbWVFcnJvcihmIntsZW4oYWN0aXZlcyl9IGFjdGl2ZSBkcml2ZXJzICh7YWN0aXZlc30pLiBUaGVyZSBzaG91bGQgb25seSBiZSBvbmUuIiknCmlmIHRhcmdldCBub3QgaW4gc3JjOgogICAgcHJpbnQoZiJXQVJOSU5HOiBwYXRjaCB0YXJnZXQgbm90IGZvdW5kIGluIHtwfSIpCiAgICBwcmludCgiRmlsZSBjb250ZW50czoiKQogICAgcHJpbnQoc3JjKQogICAgc3lzLmV4aXQoMSkKCnJlcGwgPSAoCiAgICAnaWYgbGVuKGFjdGl2ZXMpID09IDA6XG4nCiAgICAnICAgICAgICAgICAgY2xhc3MgX051bGxEcml2ZXI6XG4nCiAgICAnICAgICAgICAgICAgICAgIGRlZiBnZXRfYmVuY2htYXJrZXIoc2VsZik6IHJldHVybiBsYW1iZGEgKmEsICoqa3c6IDAuMFxuJwogICAgJyAgICAgICAgICAgICAgICBkZWYgZ2V0X2N1cnJlbnRfdGFyZ2V0KHNlbGYpOiByZXR1cm4gTm9uZVxuJwogICAgJyAgICAgICAgICAgICAgICBkZWYgaXNfYWN0aXZlKHNlbGYpOiByZXR1cm4gRmFsc2VcbicKICAgICcgICAgICAgICAgICByZXR1cm4gX051bGxEcml2Z
XIoKVxuJwogICAgJyAgICAgICAgJyArIHRhcmdldAopCnAud3JpdGVfdGV4dChzcmMucmVwbGFjZSh0YXJnZXQsIHJlcGwpKQoKIyBDbGVhciBieXRlY29kZSBjYWNoZQpweWNhY2hlID0gcC5wYXJlbnQgLyAnX19weWNhY2hlX18nCmlmIHB5Y2FjaGUuZXhpc3RzKCk6CiAgICBmb3IgcHljIGluIHB5Y2FjaGUuZ2xvYignZHJpdmVyKicpOgogICAgICAgIHB5Yy51bmxpbmsoKQogICAgICAgIHByaW50KGYiUmVtb3ZlZCBjYWNoZWQge3B5Y30iKQoKcHJpbnQoZiJQYXRjaGVkIHtwfSIpCg==' | base64 -d | python + +# Verify torch + flash-attn ABI compatibility. +# (TE verification runs at runtime on GPU nodes where Triton patch takes effect.) +RUN python -c "\ +import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}'); \ +assert torch.__version__.startswith('2.7'), f'Expected 2.7.x, got {torch.__version__}'; \ +from flash_attn import flash_attn_func; print('flash-attn OK')" + +WORKDIR /tmp/slime diff --git a/slime_multi_node_rl/README.md b/slime_multi_node_rl/README.md new file mode 100644 index 0000000..bc542af --- /dev/null +++ b/slime_multi_node_rl/README.md @@ -0,0 +1,93 @@ +# Multi-Node RL Training with Slime + +[Slime](https://github.com/THUDM/slime) is an RL training framework that uses Megatron-LM for distributed training with disaggregated rollout via SGLang. This example runs GRPO training of Qwen3-1.7B on **2 workers x 4x A10G** (8 GPUs, 24 GB VRAM each) using Anyscale. 
+
+## Cluster Layout
+
+```
+Head node (m5.2xlarge): driver only, no GPUs
+Worker 0 (4 GPUs): [GPU 0-1: Training TP=2, Stage 0] [GPU 2-3: Rollout]
+Worker 1 (4 GPUs): [GPU 0-1: Training TP=2, Stage 1] [GPU 2-3: Rollout / driver]
+```
+
+- **Training**: 4 GPUs — TP=2 x PP=2 x DP=1 (Megatron backend, PP spans workers)
+- **Rollout**: 3 GPUs — disaggregated SGLang inference, 1 GPU per engine (the fourth rollout-side GPU shown above is reserved for the training driver)
+
+## Files
+
+| File | Description |
+|------|-------------|
+| `job.yaml` | Anyscale job config (`m5.2xlarge` head + 2x `g5.12xlarge` workers) |
+| `Dockerfile.anyscale` | Docker image with Slime, Megatron-LM, SGLang, and A10G compatibility patches |
+| `anyscale-smoke-2node-a10g.sh` | Anyscale entrypoint (downloads model/data, converts weights, runs training) |
+| `patch_all_nodes.py` | Runtime patches for sgl_kernel compatibility on A10G (SM86) |
+| `run-qwen3-4B-smoke-2node-a10g.sh` | Bare-metal variant for Qwen3-4B (manual Ray cluster setup) |
+
+## Install the Anyscale CLI
+
+```bash
+pip install -U anyscale
+anyscale login
+```
+
+## Quick Start
+
+Clone the example from GitHub.
+
+```bash
+git clone https://github.com/anyscale/examples.git
+cd examples/slime_multi_node_rl
+```
+
+Submit the job.
+
+```bash
+anyscale job submit -f job.yaml
+```
+
+The entrypoint automatically:
+1. Downloads `Qwen/Qwen3-1.7B` and `zhuzilin/dapo-math-17k` to `/mnt/cluster_storage`
+2. Converts HF weights to Megatron torch_dist format (on a GPU worker)
+3. Patches all nodes for A10G compatibility (sgl_kernel SM86 workaround)
+4. Runs GRPO training with `deepscaler` reward model
+
+## Understanding the Example
+
+- The [Dockerfile.anyscale](Dockerfile.anyscale) builds on `anyscale/ray:2.54.0-py312-cu129` and installs Slime with all dependencies. It downgrades PyTorch from 2.9.1 to 2.7.1 for compatibility with pre-built flash-attn and transformer_engine wheels, and includes patches for sgl_kernel (which only ships SM100 ops from PyPI, incompatible with A10G SM86).
+- The entrypoint uses `ray job submit --entrypoint-num-gpus 1` to schedule weight conversion and the training driver on GPU worker nodes (the head node is CPU-only and cannot run Triton/CUDA code). +- Training uses Megatron-LM with TP=2, PP=2 across the two workers. Pipeline parallelism spans nodes via NCCL. +- Rollout uses disaggregated SGLang inference engines (1 GPU each) for generating training samples. + +## A10G-Specific Settings + +| Setting | Value | Reason | +|---------|-------|--------| +| `NCCL_NVLS_ENABLE` | `0` | No NVLink on cloud A10G | +| `--attention-backend` | `flash` | FA2 only (Ampere, no FA3) | +| `--sglang-attention-backend` | `flashinfer` | For SGLang on Ampere | +| `--max-tokens-per-gpu` | `4096` | Conservative for 24 GB VRAM | +| No FP8 | — | Ampere does not support FP8 | + +## Verification + +A successful run shows: +- SGLang engine startup on rollout GPUs +- Cross-node NCCL init for pipeline parallelism +- Training loss values printed each step +- Weight sync between training and rollout engines + +## If You Hit OOM + +**Training GPUs:** +1. `--max-tokens-per-gpu` -> `2048` +2. `--rollout-max-response-len` -> `1024` +3. `--n-samples-per-prompt` -> `2` and `--global-batch-size` -> `16` +4. Add `--optimizer-cpu-offload` + +**Rollout GPUs:** +1. `--sglang-mem-fraction-static` -> `0.5` +2. Add `--sglang-chunked-prefill-size 2048` + +## View the Job + +View the job in the [jobs tab](https://console.anyscale.com/jobs) of the Anyscale console. diff --git a/slime_multi_node_rl/anyscale-smoke-2node-a10g.sh b/slime_multi_node_rl/anyscale-smoke-2node-a10g.sh new file mode 100755 index 0000000..bb39472 --- /dev/null +++ b/slime_multi_node_rl/anyscale-smoke-2node-a10g.sh @@ -0,0 +1,183 @@ +#!/bin/bash +# Anyscale entrypoint: Qwen3-1.7B smoke test on 2 workers × 4x A10G +# Downloads model/dataset, converts weights, and runs a full RL training run. 
+#
+# Head node (m5.2xlarge): driver only, no GPUs
+# Layout (GPU workers):
+#   Worker 0 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 0] [GPU 2-3: Rollout]
+#   Worker 1 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 1] [GPU 2-3: Rollout]
+
+set -ex
+
+export PYTHONUNBUFFERED=1
+STORAGE=/mnt/cluster_storage
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+
+# Qwen3-1.7B model architecture args
+MODEL_ARGS=(
+    --swiglu
+    --num-layers 28
+    --hidden-size 2048
+    --ffn-hidden-size 6144
+    --num-attention-heads 16
+    --group-query-attention
+    --num-query-groups 8
+    --use-rotary-position-embeddings
+    --disable-bias-linear
+    --normalization "RMSNorm"
+    --norm-epsilon 1e-6
+    --rotary-base 1000000
+    --vocab-size 151936
+    --kv-channels 128
+    --qk-layernorm
+)
+
+# ======================== Step 1: Download model & dataset ========================
+
+echo "=== Downloading model ==="
+huggingface-cli download Qwen/Qwen3-1.7B --local-dir ${STORAGE}/Qwen3-1.7B
+
+echo "=== Downloading dataset ==="
+huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k --local-dir ${STORAGE}/dapo-math-17k
+
+# ======================== Step 2: Convert HF weights to torch_dist ========================
+
+if [ ! -d "${STORAGE}/Qwen3-1.7B_torch_dist/iter_0000000" ]; then
+    echo "=== Converting weights (HF -> torch_dist) on GPU worker ==="
+    # Run conversion on a GPU worker via Ray (head node is CPU-only; Triton
+    # crashes without GPU drivers). Output goes to shared /mnt/cluster_storage.
+ CONVERT_ENV_JSON='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM/" + } + }' + ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${CONVERT_ENV_JSON}" \ + --entrypoint-num-gpus 1 \ + -- python3 /tmp/slime/tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --no-gradient-accumulation-fusion \ + --hf-checkpoint ${STORAGE}/Qwen3-1.7B \ + --save ${STORAGE}/Qwen3-1.7B_torch_dist +else + echo "=== Converted weights already exist, skipping ===" +fi + +# ======================== Step 2.5: Patch fake_impl.py on all nodes ======================== +# sgl_kernel from PyPI ships SM100-only ops; on A10G (SM86) the C++ operators +# never load, so @torch.library.register_fake(...) crashes with +# "operator ... does not exist". This patches torch's fake_impl.py on every +# node to catch that RuntimeError gracefully. + +echo "=== Patching fake_impl.py on all nodes ===" +cp ${SCRIPT_DIR}/patch_all_nodes.py ${STORAGE}/patch_all_nodes.py +ray job submit --address="http://127.0.0.1:8265" \ + -- python3 ${STORAGE}/patch_all_nodes.py + +# ======================== Step 3: Run training ======================== + +CKPT_ARGS=( + --hf-checkpoint ${STORAGE}/Qwen3-1.7B + --ref-load ${STORAGE}/Qwen3-1.7B_torch_dist + --load ${STORAGE}/Qwen3-1.7B_torch_dist +) + +ROLLOUT_ARGS=( + --prompt-data ${STORAGE}/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --balance-data + --rm-type deepscaler + --num-rollout 20 + --rollout-batch-size 16 + --n-samples-per-prompt 4 + --rollout-max-response-len 2048 + --rollout-temperature 1 + --global-batch-size 64 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 2 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 +) + 
+GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + --sglang-attention-backend flashinfer +) + +MISC_ARGS=( + --no-gradient-accumulation-fusion + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # A10G (Ampere) — use FA2, not FA3 + --attention-backend flash + # Metrics: loss, reward, TFLOPS, pass@k are logged by default. + # Save to tensorboard for post-hoc analysis. + --use-tensorboard + --tensorboard-dir ${STORAGE}/tensorboard_logs +) + +RUNTIME_ENV_JSON='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM/", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_NVLS_ENABLE": "0", + "TENSORBOARD_DIR": "/mnt/cluster_storage/tensorboard_logs" + } +}' + +echo "=== Submitting training job ===" +# Run the driver on a GPU worker (needs CUDA for Megatron arg validation). +# Reserve 1 GPU for driver; reduce rollout to 3 (4 train + 3 rollout + 1 driver = 8). 
+ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + --entrypoint-num-gpus 1 \ + -- python3 /tmp/slime/train.py \ + --actor-num-nodes 2 \ + --actor-num-gpus-per-node 2 \ + --rollout-num-gpus 3 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/slime_multi_node_rl/job.yaml b/slime_multi_node_rl/job.yaml new file mode 100644 index 0000000..eab1611 --- /dev/null +++ b/slime_multi_node_rl/job.yaml @@ -0,0 +1,29 @@ +# Anyscale job config: Slime multi-node RL training +# Qwen3-1.7B GRPO on 2 worker nodes × 4x A10G (8 GPUs total) +# +# Layout: +# Head node (m5.2xlarge): driver only, no GPUs +# Worker 0 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 0] [GPU 2-3: Rollout] +# Worker 1 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 1] [GPU 2-3: Rollout] +# +# Submit with: +# cd examples/slime_multi_node_rl +# anyscale job submit -f job.yaml + +name: slime-qwen3-1.7b-rl-2node-a10g + +containerfile: ./Dockerfile.anyscale + +compute_config: + head_node: + instance_type: m5.2xlarge # CPU-only, runs driver script + worker_nodes: + - instance_type: g5.12xlarge # 4x A10G, 48 vCPU, 192 GB RAM + min_nodes: 2 + max_nodes: 2 + +working_dir: . + +entrypoint: bash anyscale-smoke-2node-a10g.sh + +max_retries: 0 diff --git a/slime_multi_node_rl/patch_all_nodes.py b/slime_multi_node_rl/patch_all_nodes.py new file mode 100644 index 0000000..b209270 --- /dev/null +++ b/slime_multi_node_rl/patch_all_nodes.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +"""Runtime patches for sgl_kernel incompatibility on A10G (SM86). + +sgl_kernel from PyPI only ships SM100 (Blackwell) CUDA ops. On A10G (SM86) +the C++ operators never load. This causes two classes of errors: + +1. Import-time: @torch.library.register_fake("sgl_kernel::...") crashes with + "RuntimeError: operator ... 
does not exist" + → Fix: Wrap _register_fake in try/except in torch/library.py + +2. Runtime: SGLang's forward_cuda calls torch.ops.sgl_kernel.rmsnorm which + doesn't exist → "AttributeError: no attribute 'rmsnorm'" + → Fix: Redirect MultiPlatformOp.dispatch_forward to forward_native + +This script applies both patches on all GPU worker nodes via Ray. +""" +import ray + + +@ray.remote(num_cpus=0.01) +def patch_node(): + """Apply all patches on the current node.""" + import pathlib + import importlib.util + import socket + + hostname = socket.gethostname() + patched = [] + + # ---- Locate torch ---- + torch_spec = importlib.util.find_spec("torch") + if not torch_spec or not torch_spec.submodule_search_locations: + return f"[{hostname}] torch not found, skipping" + torch_dir = pathlib.Path(torch_spec.submodule_search_locations[0]) + + # ==== Patch 1: torch/library.py ==== + # Wrap _register_fake call with try/except to handle "does not exist" errors + library_py = torch_dir / "library.py" + if library_py.exists(): + src = library_py.read_text() + if "_rf_err" in src: + patched.append("library.py (already patched)") + else: + old = " use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)\n return func" + new = ( + " try:\n" + " use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)\n" + " except RuntimeError as _rf_err:\n" + ' if "does not exist" in str(_rf_err):\n' + " pass # operator not registered (e.g., sgl_kernel on incompatible arch)\n" + " else:\n" + " raise\n" + " return func" + ) + if old in src: + src = src.replace(old, new) + library_py.write_text(src) + # Clear bytecode cache + for pyc in (torch_dir / "__pycache__").glob("library*"): + pyc.unlink() + patched.append("library.py") + else: + patched.append("library.py (target not found)") + + # ==== Patch 2: sglang multi_platform.py ==== + # Redirect dispatch_forward to forward_native on CUDA (sgl_kernel ops unavailable) + sglang_spec = importlib.util.find_spec("sglang") + if sglang_spec 
and sglang_spec.submodule_search_locations: + sglang_dir = pathlib.Path(sglang_spec.submodule_search_locations[0]) + mp_py = sglang_dir / "srt" / "layers" / "utils" / "multi_platform.py" + if mp_py.exists(): + src = mp_py.read_text() + if "sgl_kernel ops unavailable" in src: + patched.append("multi_platform.py (already patched)") + else: + old_mp = ( + " if _is_cuda:\n" + " return self.forward_cuda" + ) + new_mp = ( + " if _is_cuda:\n" + " return self.forward_native # sgl_kernel ops unavailable on A10G (SM86)" + ) + if old_mp in src: + src = src.replace(old_mp, new_mp) + mp_py.write_text(src) + # Clear bytecode cache + pycache = mp_py.parent / "__pycache__" + if pycache.exists(): + for pyc in pycache.glob("multi_platform*"): + pyc.unlink() + patched.append("multi_platform.py") + else: + patched.append("multi_platform.py (target not found)") + else: + patched.append("multi_platform.py (file not found)") + + return f"[{hostname}] patched: {', '.join(patched)}" + + +def main(): + ray.init() + + # Only patch GPU nodes — the head node (CPU-only) can't schedule tasks + nodes = [ + n for n in ray.nodes() + if n["Alive"] and n.get("Resources", {}).get("GPU", 0) > 0 + ] + print(f"Patching {len(nodes)} GPU nodes...") + + refs = [] + for node in nodes: + ip = node["NodeManagerAddress"] + ref = patch_node.options( + resources={f"node:{ip}": 0.001} + ).remote() + refs.append(ref) + + results = ray.get(refs) + for r in results: + print(f" {r}") + + print("All nodes patched!") + + +if __name__ == "__main__": + main() diff --git a/slime_multi_node_rl/run-qwen3-4B-smoke-2node-a10g.sh b/slime_multi_node_rl/run-qwen3-4B-smoke-2node-a10g.sh new file mode 100755 index 0000000..06853ad --- /dev/null +++ b/slime_multi_node_rl/run-qwen3-4B-smoke-2node-a10g.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# Smoke test: Qwen3-4B on 2 nodes × 4x A10G (bare-metal) +# Layout: +# Node 0 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 0] [GPU 2-3: Rollout] +# Node 1 (4 GPUs): [GPU 0-1: Training TP=2, PP Stage 
1] [GPU 2-3: Rollout]
+# Training: 4 GPUs (TP=2 × PP=2 × DP=1), Rollout: 4 GPUs (disaggregated)
+
+# PYTHONUNBUFFERED (any non-empty value) keeps Ray from buffering stdout/stderr
+export PYTHONUNBUFFERED=1
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+
+# Qwen3-4B model architecture args
+MODEL_ARGS=(
+    --swiglu
+    --num-layers 36
+    --hidden-size 2560
+    --ffn-hidden-size 9728
+    --num-attention-heads 32
+    --group-query-attention
+    --num-query-groups 8
+    --use-rotary-position-embeddings
+    --disable-bias-linear
+    --normalization "RMSNorm"
+    --norm-epsilon 1e-6
+    --rotary-base 1000000
+    --vocab-size 151936
+    --kv-channels 128
+    --qk-layernorm
+)
+
+CKPT_ARGS=(
+    --hf-checkpoint /root/Qwen3-4B
+    --ref-load /root/Qwen3-4B_torch_dist
+    --load /root/Qwen3-4B_torch_dist
+)
+
+ROLLOUT_ARGS=(
+    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
+    --input-key prompt
+    --label-key label
+    --apply-chat-template
+    --rollout-shuffle
+    --balance-data
+    --rm-type deepscaler
+    --num-rollout 3
+    --rollout-batch-size 8
+    --n-samples-per-prompt 4
+    --rollout-max-response-len 2048
+    --rollout-temperature 1
+    --global-batch-size 32
+)
+
+PERF_ARGS=(
+    --tensor-model-parallel-size 2
+    --sequence-parallel
+    --pipeline-model-parallel-size 2
+    --context-parallel-size 1
+    --expert-model-parallel-size 1
+    --expert-tensor-parallel-size 1
+
+    --recompute-granularity full
+    --recompute-method uniform
+    --recompute-num-layers 1
+
+    --use-dynamic-batch-size
+    --max-tokens-per-gpu 4096
+)
+
+GRPO_ARGS=(
+    --advantage-estimator grpo
+    --use-kl-loss
+    --kl-loss-coef 0.00
+    --kl-loss-type low_var_kl
+    --entropy-coef 0.00
+    --eps-clip 0.2
+    --eps-clip-high 0.28
+)
+
+OPTIMIZER_ARGS=(
+    --optimizer adam
+    --lr 1e-6
+    --lr-decay-style constant
+    --weight-decay 0.1
+    --adam-beta1 0.9
+    --adam-beta2 0.98
+)
+
+WANDB_ARGS=(
+    # --use-wandb
+    # --wandb-project slime-dev
+    # --wandb-group qwen3-4B-smoke-2node-a10g
+)
+
+SGLANG_ARGS=(
+    --rollout-num-gpus-per-engine 1
+    --sglang-mem-fraction-static 0.7
+    --sglang-attention-backend flashinfer
+)
+
+MISC_ARGS=(
+    # default dropout in megatron is 0.1
+    --attention-dropout 0.0
+    --hidden-dropout 0.0
+    # should be good for model performance
+    --accumulate-allreduce-grads-in-fp32
+    --attention-softmax-in-fp32
+    # A10G (Ampere) — use FA2, not FA3
+    --attention-backend flash
+)
+
+# --- Multi-Node Ray Cluster Setup ---
+# On Anyscale: skip ray start commands, the Ray cluster is managed automatically.
+# For bare-metal / non-Anyscale, uncomment the following:
+#
+# export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+# ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+#
+# On each worker node:
+# ray start --address=${MASTER_ADDR}:6379 --node-ip-address ${WORKER_ADDR} --num-gpus 4 --disable-usage-stats
+
+export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+RUNTIME_ENV_JSON="{
+    \"env_vars\": {
+        \"PYTHONPATH\": \"/root/Megatron-LM/\",
+        \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
+        \"NCCL_NVLS_ENABLE\": \"0\",
+        \"MASTER_ADDR\": \"${MASTER_ADDR}\"
+    }
+}"
+
+ray job submit --address="http://127.0.0.1:8265" \
+    --runtime-env-json="${RUNTIME_ENV_JSON}" \
+    -- python3 /tmp/slime/train.py \
+    --actor-num-nodes 2 \
+    --actor-num-gpus-per-node 2 \
+    --rollout-num-gpus 4 \
+    ${MODEL_ARGS[@]} \
+    ${CKPT_ARGS[@]} \
+    ${ROLLOUT_ARGS[@]} \
+    ${OPTIMIZER_ARGS[@]} \
+    ${GRPO_ARGS[@]} \
+    ${WANDB_ARGS[@]} \
+    ${PERF_ARGS[@]} \
+    ${SGLANG_ARGS[@]} \
+    ${MISC_ARGS[@]}