From b4f20dac320ea9af9a9edcbdee401eea1e0a7f01 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Tue, 17 Feb 2026 14:47:42 -0800 Subject: [PATCH] Add SGLang Ray Direct Transport (RDT) weight sync example Demonstrates transferring model weights from a HuggingFace trainer actor to an SGLang SchedulerActor using Ray Direct Transport with NCCL, then verifying parameter correctness. Co-authored-by: Cursor --- ray_rdt/Dockerfile | 61 ++++++++ ray_rdt/job_test_rdt_weight_sync.yaml | 40 +++++ ray_rdt/test_rdt_weight_sync.py | 204 ++++++++++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100644 ray_rdt/Dockerfile create mode 100644 ray_rdt/job_test_rdt_weight_sync.yaml create mode 100644 ray_rdt/test_rdt_weight_sync.py diff --git a/ray_rdt/Dockerfile b/ray_rdt/Dockerfile new file mode 100644 index 0000000..c787649 --- /dev/null +++ b/ray_rdt/Dockerfile @@ -0,0 +1,61 @@ +FROM anyscale/ray:2.53.0-py312-cu129 + +# ============================================================================= +# System Dependencies +# ============================================================================= +RUN sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + ninja-build \ + libnuma1 \ + libnuma-dev \ + numactl \ + git \ + curl \ + wget \ + netcat \ + && sudo rm -rf /var/lib/apt/lists/* + +# ============================================================================= +# CUDA Toolkit (nvcc compiler) - CUDA 12.9 to match base image +# ============================================================================= +RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb -o /tmp/cuda-keyring.deb && \ + sudo dpkg -i /tmp/cuda-keyring.deb && \ + rm /tmp/cuda-keyring.deb && \ + sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends \ + cuda-nvcc-12-9 \ + cuda-cudart-dev-12-9 \ + cuda-crt-12-9 \ + && sudo rm -rf /var/lib/apt/lists/* + +# Create/update CUDA symlink +RUN sudo rm -rf /usr/local/cuda && \ + sudo ln -s /usr/local/cuda-12.9 /usr/local/cuda + +# ============================================================================= +# Environment Variables +# ============================================================================= +ENV PATH="/usr/local/cuda/bin:${PATH}" +ENV CUDA_HOME="/usr/local/cuda" +ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" +ENV RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES="1" + +# ============================================================================= +# Python Dependencies +# ============================================================================= +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/home/ray/.local/bin:${PATH}" + +RUN uv pip install --system \ + "sglang[all] @ git+https://github.com/xyuzh/sglang.git@feature/ray-actor-scheduler#subdirectory=python" \ + numpy \ + transformers \ + accelerate \ + huggingface_hub \ + requests \ + httpx + + +WORKDIR /home/ray/default diff --git a/ray_rdt/job_test_rdt_weight_sync.yaml b/ray_rdt/job_test_rdt_weight_sync.yaml new file mode 100644 index 0000000..3308af2 --- /dev/null +++ b/ray_rdt/job_test_rdt_weight_sync.yaml @@ -0,0 +1,40 @@ +# Anyscale Job: RDT Weight Sync Correctness Test +# +# Transfers model weights from a HuggingFace "trainer" actor to an SGLang +# SchedulerActor via Ray Direct Transport (NCCL), then verifies every +# parameter matches. +# +# Configuration: 2 GPUs on a single node (1 producer + 1 scheduler) +# Submit: anyscale job submit -f ray_rdt/job_test_rdt_weight_sync.yaml + +name: sglang-test-rdt-weight-sync + +cloud: + +compute_config: + head_node: + instance_type: g5.12xlarge # 4x A10G (need at least 2 GPUs) + +containerfile: Dockerfile + +working_dir: . + +entrypoint: > + set -e && + echo "==========================================" && + echo "RDT Weight Sync Correctness Test Setup" && + echo "==========================================" && + pip install --no-cache-dir -q "numpy>=1.26.0,<2.0" && + pip install --no-cache-dir -q sgl-kernel --force-reinstall && + echo "Running RDT weight sync test..." && + echo "==========================================" && + python test_rdt_weight_sync.py --model-path Qwen/Qwen3-0.6B + +env_vars: + NCCL_DEBUG: INFO + NCCL_IB_DISABLE: "0" + NCCL_NET_GDR_LEVEL: "2" + RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES: "1" + SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK: "1" + +max_retries: 0 diff --git a/ray_rdt/test_rdt_weight_sync.py b/ray_rdt/test_rdt_weight_sync.py new file mode 100644 index 0000000..7d35cfe --- /dev/null +++ b/ray_rdt/test_rdt_weight_sync.py @@ -0,0 +1,204 @@ +""" +RDT Weight Sync Correctness Test + +Verifies that model weights can be correctly transferred from a "trainer" +Ray actor to an SGLang SchedulerActor using Ray Direct Transport (RDT) +with NCCL. + +Key idea: SchedulerActor.__init__ loads the model. By *not* calling +run_event_loop() the actor is free to receive regular Ray method calls, +which makes the test trivially sequential. + +Requires 2 GPUs on the same node (one for the producer, one for the +scheduler). + +Usage: + python test_rdt_weight_sync.py --model-path Qwen/Qwen3-0.6B +""" + +import argparse +import json +import sys +import time + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + +# --------------------------------------------------------------------------- +# Weight-producer actor (simulates a trainer holding updated weights) +# --------------------------------------------------------------------------- +@ray.remote(num_gpus=1) +class WeightProducerActor: + """Loads a HuggingFace model and exposes its weights via RDT.""" + + def __init__(self, model_path: str): + import torch + from transformers import AutoModelForCausalLM + + self.model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16 + ).cuda() + self._build_bucket() + + # -- internal --------------------------------------------------------- + def _build_bucket(self): + from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket + + named_tensors = [ + (name, param.data) for name, param in self.model.named_parameters() + ] + self._bucket = FlattenedTensorBucket(named_tensors=named_tensors) + + # -- public API ------------------------------------------------------- + @ray.method(tensor_transport="nccl") + def get_weights(self) -> torch.Tensor: + """Return the flattened weight tensor (transferred via NCCL/RDT).""" + return self._bucket.get_flattened_tensor() + + def get_metadata_json(self) -> str: + """Return bucket metadata as a JSON string (object-store path).""" + metadata = self._bucket.get_metadata() + return json.dumps( + [ + { + "name": m.name, + "shape": list(m.shape), + "dtype": str(m.dtype).replace("torch.", ""), + "start_idx": m.start_idx, + "end_idx": m.end_idx, + "numel": m.numel, + } + for m in metadata + ] + ) + + def get_param(self, name: str) -> torch.Tensor: + """Return a single parameter tensor (on CPU) for verification.""" + return self.model.state_dict()[name].cpu() + + def get_param_names(self) -> list: + return list(self.model.state_dict().keys()) + + +# --------------------------------------------------------------------------- +# Main test driver +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser(description="RDT Weight Sync Correctness Test") + parser.add_argument("--model-path", type=str, default="Qwen/Qwen3-0.6B") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + print(f"\n{'='*60}") + print("RDT Weight Sync Correctness Test") + print(f"{'='*60}") + print(f"Model: {args.model_path}") + + # ------------------------------------------------------------------ + # 1. Placement group: 2 GPUs on the same node + # ------------------------------------------------------------------ + pg = placement_group( + bundles=[{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}], + strategy="STRICT_PACK", + ) + ray.get(pg.ready()) + print("Placement group ready (2 GPUs).\n") + + # ------------------------------------------------------------------ + # 2. Create the weight producer (bundle 0) + # ------------------------------------------------------------------ + print("Creating WeightProducerActor …") + producer = WeightProducerActor.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + ), + ).remote(args.model_path) + + # ------------------------------------------------------------------ + # 3. Create the SchedulerActor (bundle 1) — don't start event loop + # ------------------------------------------------------------------ + print("Creating SchedulerActor …") + from sglang.srt.managers.scheduler_actor import SchedulerActor + from sglang.srt.server_args import PortArgs, ServerArgs + + server_args = ServerArgs( + model_path=args.model_path, + tp_size=1, + port=args.port, + ) + port_args = PortArgs.init_new(server_args) + + scheduler = SchedulerActor.options( + num_gpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=1, + ), + ).remote( + server_args=server_args, + port_args=port_args, + gpu_id=0, + tp_rank=0, + moe_ep_rank=0, + pp_rank=0, + dp_rank=0, + ) + + # Wait for both actors to be ready + print("Waiting for actors to initialise …") + ray.get(scheduler.get_info.remote()) + param_names = ray.get(producer.get_param_names.remote()) + print(f"Both actors ready. Model has {len(param_names)} state-dict entries.\n") + + # ------------------------------------------------------------------ + # 4. Transfer weights: metadata via object store, tensor via RDT + # ------------------------------------------------------------------ + print("Step 1: Fetch metadata (object store) …") + t0 = time.time() + metadata_json = ray.get(producer.get_metadata_json.remote()) + print(f" Done ({time.time() - t0:.2f}s)") + + print("Step 2: Transfer weights (RDT / NCCL) …") + t0 = time.time() + weights_ref = producer.get_weights.remote() # RDT-enabled ObjectRef + ok = ray.get(scheduler.receive_weights_rdt.remote(weights_ref, metadata_json)) + elapsed = time.time() - t0 + print(f" Done — success={ok} ({elapsed:.2f}s)") + + # ------------------------------------------------------------------ + # 5. Verify: compare every parameter between producer and scheduler + # ------------------------------------------------------------------ + print(f"\nStep 3: Verifying {len(param_names)} parameters …") + mismatches = [] + for i, name in enumerate(param_names): + p_tensor = ray.get(producer.get_param.remote(name)) + s_tensor = ray.get(scheduler.get_param.remote(name)) + if not torch.allclose(p_tensor, s_tensor, atol=1e-6): + mismatches.append(name) + print(f" MISMATCH [{i}] {name}") + elif i % 20 == 0: + print(f" [{i}/{len(param_names)}] {name} ✓") + + print(f"\nChecked {len(param_names)} parameters, {len(mismatches)} mismatches.") + if mismatches: + print("FAILURE — mismatched parameters:") + for n in mismatches: + print(f" - {n}") + return 1 + + print("SUCCESS: All parameters match!") + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + ray.util.remove_placement_group(pg) + print("\nDone.") + return 0 + + +if __name__ == "__main__": + sys.exit(main())