Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions ray_rdt/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Base: Anyscale Ray 2.53.0 image — Python 3.12, CUDA 12.9 runtime.
FROM anyscale/ray:2.53.0-py312-cu129

# =============================================================================
# System Dependencies
# =============================================================================
# Build toolchain (gcc/cmake/ninja) and NUMA libraries for compiling native
# extensions; git/curl/wget/netcat as general utilities. apt lists are purged
# in the same layer to keep the image small.
# NOTE(review): "netcat" is a virtual package on recent Ubuntu releases —
# confirm the base image's distro still resolves it (otherwise use
# netcat-openbsd explicitly).
RUN sudo apt-get update && \
    sudo apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        ninja-build \
        libnuma1 \
        libnuma-dev \
        numactl \
        git \
        curl \
        wget \
        netcat \
    && sudo rm -rf /var/lib/apt/lists/*

# =============================================================================
# CUDA Toolkit (nvcc compiler) - CUDA 12.9 to match base image
# =============================================================================
# Adds NVIDIA's apt repo via the cuda-keyring package, then installs nvcc and
# the CUDA dev headers (presumably absent from the runtime-only base image —
# this step exists so GPU kernels can be compiled inside the container).
RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb -o /tmp/cuda-keyring.deb && \
    sudo dpkg -i /tmp/cuda-keyring.deb && \
    rm /tmp/cuda-keyring.deb && \
    sudo apt-get update && \
    sudo apt-get install -y --no-install-recommends \
        cuda-nvcc-12-9 \
        cuda-cudart-dev-12-9 \
        cuda-crt-12-9 \
    && sudo rm -rf /var/lib/apt/lists/*

# Create/update CUDA symlink so tools that hard-code /usr/local/cuda resolve
# to the 12.9 toolkit installed above.
RUN sudo rm -rf /usr/local/cuda && \
    sudo ln -s /usr/local/cuda-12.9 /usr/local/cuda

# =============================================================================
# Environment Variables
# =============================================================================
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV CUDA_HOME="/usr/local/cuda"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
# Per the variable name: Ray will NOT rewrite CUDA_VISIBLE_DEVICES for actors
# (required for the RDT/NCCL weight-transfer test shipped alongside this image).
ENV RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES="1"

# =============================================================================
# Python Dependencies
# =============================================================================
# Install uv; its installer drops the binary into ~/.local/bin, which is
# added to PATH on the next line.
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/home/ray/.local/bin:${PATH}"

# SGLang from a fork branch ("feature/ray-actor-scheduler") plus runtime deps,
# installed into the system interpreter (--system) rather than a venv.
RUN uv pip install --system \
    "sglang[all] @ git+https://github.com/xyuzh/sglang.git@feature/ray-actor-scheduler#subdirectory=python" \
    numpy \
    transformers \
    accelerate \
    huggingface_hub \
    requests \
    httpx


WORKDIR /home/ray/default
40 changes: 40 additions & 0 deletions ray_rdt/job_test_rdt_weight_sync.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Anyscale Job: RDT Weight Sync Correctness Test
#
# Transfers model weights from a HuggingFace "trainer" actor to an SGLang
# SchedulerActor via Ray Direct Transport (NCCL), then verifies every
# parameter matches.
#
# Configuration: 2 GPUs on a single node (1 producer + 1 scheduler)
# Submit: anyscale job submit -f ray_rdt/job_test_rdt_weight_sync.yaml

name: sglang-test-rdt-weight-sync

# NOTE(review): `cloud:` is intentionally(?) empty — presumably this falls
# back to the account's default cloud; confirm against the Anyscale job schema.
cloud:

compute_config:
  head_node:
    instance_type: g5.12xlarge # 4x A10G (need at least 2 GPUs)

containerfile: Dockerfile

working_dir: .

# Folded scalar (>): the newlines below become spaces, so this runs as a single
# shell command chained with `&&` — each step must succeed for the next to run.
# The two pip installs pin numpy < 2 and force-reinstall sgl-kernel on top of
# whatever the image baked in.
entrypoint: >
  set -e &&
  echo "==========================================" &&
  echo "RDT Weight Sync Correctness Test Setup" &&
  echo "==========================================" &&
  pip install --no-cache-dir -q "numpy>=1.26.0,<2.0" &&
  pip install --no-cache-dir -q sgl-kernel --force-reinstall &&
  echo "Running RDT weight sync test..." &&
  echo "==========================================" &&
  python test_rdt_weight_sync.py --model-path Qwen/Qwen3-0.6B

env_vars:
  # Verbose NCCL logging so transport failures show up in the job log.
  NCCL_DEBUG: INFO
  NCCL_IB_DISABLE: "0"
  NCCL_NET_GDR_LEVEL: "2"
  # Per the variable name: Ray will NOT rewrite CUDA_VISIBLE_DEVICES for actors.
  RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES: "1"
  SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK: "1"

# Correctness test: never retry — a retry would mask flaky transfers.
max_retries: 0
204 changes: 204 additions & 0 deletions ray_rdt/test_rdt_weight_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
RDT Weight Sync Correctness Test

Verifies that model weights can be correctly transferred from a "trainer"
Ray actor to an SGLang SchedulerActor using Ray Direct Transport (RDT)
with NCCL.

Key idea: SchedulerActor.__init__ loads the model. By *not* calling
run_event_loop() the actor is free to receive regular Ray method calls,
which makes the test trivially sequential.

Requires 2 GPUs on the same node (one for the producer, one for the
scheduler).

Usage:
python test_rdt_weight_sync.py --model-path Qwen/Qwen3-0.6B
"""

import argparse
import json
import sys
import time

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


# ---------------------------------------------------------------------------
# Weight-producer actor (simulates a trainer holding updated weights)
# ---------------------------------------------------------------------------
@ray.remote(num_gpus=1)
class WeightProducerActor:
    """Holds a HuggingFace model on one GPU and serves its weights.

    The flattened weight bucket is built once at construction time.
    ``get_weights`` ships the flat tensor over NCCL via Ray Direct
    Transport; metadata and per-parameter accessors go through the
    regular object store.
    """

    def __init__(self, model_path: str):
        import torch
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16
        )
        self.model = model.cuda()
        self._build_bucket()

    # -- internal ---------------------------------------------------------
    def _build_bucket(self):
        """Flatten every named parameter into one contiguous bucket."""
        from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

        pairs = []
        for param_name, param in self.model.named_parameters():
            pairs.append((param_name, param.data))
        self._bucket = FlattenedTensorBucket(named_tensors=pairs)

    # -- public API -------------------------------------------------------
    @ray.method(tensor_transport="nccl")
    def get_weights(self) -> torch.Tensor:
        """Return the flattened weight tensor (transferred via NCCL/RDT)."""
        return self._bucket.get_flattened_tensor()

    def get_metadata_json(self) -> str:
        """Return bucket metadata as a JSON string (object-store path)."""
        entries = []
        for meta in self._bucket.get_metadata():
            entries.append(
                {
                    "name": meta.name,
                    "shape": list(meta.shape),
                    "dtype": str(meta.dtype).replace("torch.", ""),
                    "start_idx": meta.start_idx,
                    "end_idx": meta.end_idx,
                    "numel": meta.numel,
                }
            )
        return json.dumps(entries)

    def get_param(self, name: str) -> torch.Tensor:
        """Return a single parameter tensor (on CPU) for verification."""
        tensor = self.model.state_dict()[name]
        return tensor.cpu()

    def get_param_names(self) -> list:
        """List every key of the model's state dict, in order."""
        return [*self.model.state_dict()]


# ---------------------------------------------------------------------------
# Main test driver
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Main test driver
# ---------------------------------------------------------------------------
def _create_producer(pg, model_path: str):
    """Start the WeightProducerActor on bundle 0 of *pg*."""
    print("Creating WeightProducerActor …")
    return WeightProducerActor.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=0,
        ),
    ).remote(model_path)


def _create_scheduler(pg, model_path: str, port: int):
    """Start an SGLang SchedulerActor on bundle 1 of *pg*.

    The actor's event loop is deliberately NOT started, so it keeps
    answering plain Ray method calls (the whole point of this test).
    """
    print("Creating SchedulerActor …")
    from sglang.srt.managers.scheduler_actor import SchedulerActor
    from sglang.srt.server_args import PortArgs, ServerArgs

    server_args = ServerArgs(
        model_path=model_path,
        tp_size=1,
        port=port,
    )
    port_args = PortArgs.init_new(server_args)

    return SchedulerActor.options(
        num_gpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=1,
        ),
    ).remote(
        server_args=server_args,
        port_args=port_args,
        gpu_id=0,
        tp_rank=0,
        moe_ep_rank=0,
        pp_rank=0,
        dp_rank=0,
    )


def _transfer_weights(producer, scheduler) -> None:
    """Send weights producer → scheduler: metadata via the object store,
    the flattened tensor via RDT (NCCL)."""
    print("Step 1: Fetch metadata (object store) …")
    t0 = time.time()
    metadata_json = ray.get(producer.get_metadata_json.remote())
    print(f"  Done ({time.time() - t0:.2f}s)")

    print("Step 2: Transfer weights (RDT / NCCL) …")
    t0 = time.time()
    weights_ref = producer.get_weights.remote()  # RDT-enabled ObjectRef
    ok = ray.get(scheduler.receive_weights_rdt.remote(weights_ref, metadata_json))
    print(f"  Done — success={ok} ({time.time() - t0:.2f}s)")


def _verify_params(producer, scheduler, param_names) -> list:
    """Compare every parameter between producer and scheduler.

    Returns the list of mismatched parameter names (empty on success).
    """
    print(f"\nStep 3: Verifying {len(param_names)} parameters …")
    mismatches = []
    for i, name in enumerate(param_names):
        p_tensor = ray.get(producer.get_param.remote(name))
        s_tensor = ray.get(scheduler.get_param.remote(name))
        if not torch.allclose(p_tensor, s_tensor, atol=1e-6):
            mismatches.append(name)
            print(f"  MISMATCH [{i}] {name}")
        elif i % 20 == 0:
            print(f"  [{i}/{len(param_names)}] {name} ✓")
    return mismatches


def main() -> int:
    """Drive the RDT weight-sync correctness test.

    Creates a 2-GPU placement group, starts the producer and the scheduler
    on it, transfers the flattened weights over RDT/NCCL, then compares
    every parameter between the two actors.

    Returns:
        0 if every parameter matches, 1 if any mismatch is found.
    """
    parser = argparse.ArgumentParser(description="RDT Weight Sync Correctness Test")
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen3-0.6B")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    print(f"\n{'='*60}")
    print("RDT Weight Sync Correctness Test")
    print(f"{'='*60}")
    print(f"Model: {args.model_path}")

    # ------------------------------------------------------------------
    # 1. Placement group: 2 GPUs on the same node. STRICT_PACK forces
    #    both bundles onto one node (NCCL transfer is intra-node here).
    # ------------------------------------------------------------------
    pg = placement_group(
        bundles=[{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}],
        strategy="STRICT_PACK",
    )
    ray.get(pg.ready())
    print("Placement group ready (2 GPUs).\n")

    # try/finally so the placement group is released on EVERY exit path —
    # previously it leaked on the mismatch return and on any exception.
    try:
        producer = _create_producer(pg, args.model_path)
        scheduler = _create_scheduler(pg, args.model_path, args.port)

        # Wait for both actors to finish their heavyweight __init__
        # (model load) before starting the transfer.
        print("Waiting for actors to initialise …")
        ray.get(scheduler.get_info.remote())
        param_names = ray.get(producer.get_param_names.remote())
        print(f"Both actors ready. Model has {len(param_names)} state-dict entries.\n")

        _transfer_weights(producer, scheduler)

        mismatches = _verify_params(producer, scheduler, param_names)
        print(f"\nChecked {len(param_names)} parameters, {len(mismatches)} mismatches.")
        if mismatches:
            print("FAILURE — mismatched parameters:")
            for n in mismatches:
                print(f"  - {n}")
            return 1

        print("SUCCESS: All parameters match!")
        return 0
    finally:
        # ------------------------------------------------------------------
        # Cleanup
        # ------------------------------------------------------------------
        ray.util.remove_placement_group(pg)
        print("\nDone.")


if __name__ == "__main__":
    sys.exit(main())