4 changes: 4 additions & 0 deletions roll/configs/data_args.py
@@ -20,6 +20,10 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to cache preprocessed datasets."},
+    )
     file_name: Optional[Union[List[str], str]] = field(
         default=None,
         metadata={"help": "The name of file path name for train. Conflicts with `--dataset_name`"},
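Note: a minimal sketch of how a `cache_path` like this is typically consumed, assuming HuggingFace `datasets`; the `load_or_build_dataset` helper and its wiring are illustrative, not part of this PR:

```python
import os
from typing import Callable, Optional

import datasets


def load_or_build_dataset(
    cache_path: Optional[str], build_fn: Callable[[], datasets.Dataset]
) -> datasets.Dataset:
    # Reuse the preprocessed dataset when cache_path points at a populated directory.
    if cache_path and os.path.isdir(cache_path):
        return datasets.load_from_disk(cache_path)
    dataset = build_fn()
    # Persist the freshly preprocessed dataset so later runs skip the map/filter work.
    if cache_path:
        dataset.save_to_disk(cache_path)
    return dataset
```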
6 changes: 3 additions & 3 deletions roll/distributed/scheduler/generate_scheduler.py
@@ -473,12 +473,12 @@ def __init__(

         self.dataset = dataset
         self.indices = list(range(len(dataset)))
-        if state is not None and state.get("dataset_iter_count", 0) > 0:
-            for _ in range(state["dataset_iter_count"]):
-                self.get_next_dataset_item()
         self.dataset_epoch = 0
         self.dataset_iter = None
         self.dataset_iter_count = 0
+        if state is not None and state.get("dataset_iter_count", 0) > 0:
+            for _ in range(state["dataset_iter_count"]):
+                self.get_next_dataset_item()

         self.collect_fn_cls = collect_fn_cls
         self.collect_fn_kwargs = collect_fn_kwargs
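Note: the reorder matters because replaying `get_next_dataset_item()` presumably advances `dataset_iter` and `dataset_iter_count`; zero-initializing those fields after the replay would clobber the restored position. A toy sketch of the corrected restore pattern, with hypothetical names:

```python
class ResumableSampler:
    """Replays consumed items on restore; field initialization must precede the replay."""

    def __init__(self, items, state=None):
        self.items = list(items)
        # Initialize iteration bookkeeping first...
        self.it = iter(self.items)
        self.count = 0
        # ...then replay, so the advanced position is not reset afterwards.
        if state is not None and state.get("count", 0) > 0:
            for _ in range(state["count"]):
                self.next_item()

    def next_item(self):
        self.count += 1
        return next(self.it)

    def get_state(self):
        return {"count": self.count}


sampler = ResumableSampler(range(10))
sampler.next_item()
resumed = ResumableSampler(range(10), state=sampler.get_state())
assert resumed.next_item() == 1  # continues where the first sampler left off
```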
8 changes: 8 additions & 0 deletions roll/distributed/scheduler/initialize.py
@@ -1,4 +1,5 @@
 import os
+import shlex
 import subprocess
 import sys
 import time
@@ -36,12 +37,19 @@ def start_ray_cluster():
         logger.info("Ray cluster already initialized")
         return False

+    temp_dir = os.environ.get("ROLL_RAY_TEMP_DIR")
+    temp_dir_arg = ""
+    if temp_dir:
+        os.makedirs(temp_dir, exist_ok=True)
+        temp_dir_arg = f" --temp-dir={shlex.quote(temp_dir)}"
+
     if rank == 0:
         cmd = f"ray start --head --port={master_port} --node-name={node_name} --dashboard-port={dashboard_port}"
     else:
         # fix: handle the head/worker node creation-order inconsistency that can appear at large scale
         time.sleep(5)
         cmd = f"ray start --address={master_addr}:{master_port} --node-name={node_name} --dashboard-port={dashboard_port}"
+    cmd += temp_dir_arg

     logger.info(f"Starting ray cluster: {cmd}")
     ret = subprocess.run(cmd, shell=True, capture_output=True)
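Note: with this change, redirecting Ray's session files (logs, spilled objects, sockets) to a larger volume only requires setting the environment variable before launch; `shlex.quote` matters because the command runs under `shell=True`. A small illustration with a made-up path:

```python
import os
import shlex

# Illustrative only: point Ray's temp dir at a roomy volume before starting ROLL.
os.environ["ROLL_RAY_TEMP_DIR"] = "/mnt/big-disk/ray tmp"

temp_dir = os.environ["ROLL_RAY_TEMP_DIR"]
# Quoting keeps paths with spaces or shell metacharacters intact under shell=True.
print(f"ray start --head --temp-dir={shlex.quote(temp_dir)}")
# -> ray start --head --temp-dir='/mnt/big-disk/ray tmp'
```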
47 changes: 34 additions & 13 deletions roll/pipeline/rlvr/rlvr_vlm_pipeline.py
@@ -1,5 +1,6 @@
 import copy
 import json
+import math
 import os
 import uuid
 from functools import partial
@@ -224,22 +225,31 @@ def __init__(self, pipeline_config: RLVRConfig):
         dataset = get_vlm_dataset(
             self.pipeline_config.actor_train.data_args, encode_function, self.processor, get_eval=False
         )
-        # update domain field, DynamicSamplingScheduler requires
-        dataset = dataset.map(
-            partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
-            num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
-            desc="update_dataset_domain",
-            load_from_cache_file=False,
-        )
-
+        # Avoid rewriting large multimodal Arrow columns when a run has a single domain.
+        # Re-mapping PIL image columns can overflow pyarrow's regular array offsets.
+        domains = list(self.pipeline_config.actor_train.data_args.domain_interleave_probs.keys())
         self.domain_datasets: Dict[str, datasets.Dataset] = {}
-        for domain in self.pipeline_config.actor_train.data_args.domain_interleave_probs.keys():
-            self.domain_datasets[domain] = dataset.filter(
-                lambda example, dom: example["domain"] == dom,
-                num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
-                fn_kwargs={"dom": domain},
-            )
-            assert len(self.domain_datasets[domain]) > 0, f"domain dataset {domain} has no data"
+        if len(domains) == 1:
+            domain = domains[0]
+            if "domain" not in dataset.column_names:
+                dataset = dataset.add_column("domain", [domain] * len(dataset))
+            self.domain_datasets[domain] = dataset
+        else:
+            # update domain field, DynamicSamplingScheduler requires it
+            dataset = dataset.map(
+                partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
+                num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
+                desc="update_dataset_domain",
+                load_from_cache_file=False,
+            )
+
+            for domain in domains:
+                self.domain_datasets[domain] = dataset.filter(
+                    lambda example, dom: example["domain"] == dom,
+                    num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
+                    fn_kwargs={"dom": domain},
+                )
+                assert len(self.domain_datasets[domain]) > 0, f"domain dataset {domain} has no data"

         self.val_dataset = None
         if self.pipeline_config.validation and self.pipeline_config.validation.data_args:
@@ -259,6 +269,17 @@ def __init__(self, pipeline_config: RLVRConfig):
             kl_horizon=self.pipeline_config.kl_horizon,
         )

+        if self.pipeline_config.max_steps <= 0:
+            num_train_epochs = self.pipeline_config.actor_train.training_args.num_train_epochs
+            dataset_size = sum(len(domain_dataset) for domain_dataset in self.domain_datasets.values())
+            inferred_max_steps = math.ceil(num_train_epochs * dataset_size / self.pipeline_config.rollout_batch_size)
+            logger.info(
+                "infer pipeline max_steps from dataset: "
+                f"num_train_epochs={num_train_epochs}, dataset_size={dataset_size}, "
+                f"rollout_batch_size={self.pipeline_config.rollout_batch_size}, "
+                f"max_steps={inferred_max_steps}"
+            )
+            self.pipeline_config.max_steps = inferred_max_steps
         assert self.pipeline_config.max_steps > 0, "max_steps must be greater than 0"
         self.pipeline_config.set_max_steps(max_steps=self.pipeline_config.max_steps)
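Note: the inferred value is just `ceil(num_train_epochs * dataset_size / rollout_batch_size)`. For example, 2 epochs over 10,000 prompts with a rollout batch of 64 yields `ceil(20000 / 64) = 313` pipeline steps. A standalone restatement of that arithmetic:

```python
import math


def infer_max_steps(num_train_epochs: float, dataset_size: int, rollout_batch_size: int) -> int:
    # Total rollout prompts across all epochs, divided into rollout-sized batches.
    return math.ceil(num_train_epochs * dataset_size / rollout_batch_size)


assert infer_max_steps(2, 10_000, 64) == 313
assert infer_max_steps(1, 64, 64) == 1
```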
40 changes: 35 additions & 5 deletions roll/third_party/deepspeed/model_update.py
@@ -1,3 +1,5 @@
+from dataclasses import asdict
+
 import ray
 import torch.distributed as dist
 from deepspeed.runtime.zero import GatheredParameters
@@ -29,9 +31,22 @@ def _gather_weights(is_zero3, named_params):
     return [(n, p.data) for n, p in named_params]


-def gather_deepspeed_weights(model, ds_config, buffer_size):
+def _get_deepspeed_named_params(model, ds_config, is_lora=False):
+    if not is_lora:
+        return [(name, param) for name, param in model.named_parameters()]
+
+    if not ds_config.is_zero3():
+        return [(name, param) for name, param in get_peft_model_state_dict(model).items()]
+
+    adapter_name = "default"
+    state_dict = model.state_dict()
+    lora_state_dict = {k: state_dict[k] for k in state_dict if ("lora_" in k and adapter_name in k)}
+    return [(name.replace(f".{adapter_name}", ""), model.get_parameter(name)) for name in lora_state_dict]
+
+
+def gather_deepspeed_weights(model, ds_config, buffer_size, is_lora=False):
     is_zero3 = ds_config.is_zero3()
-    named_params = [(name, param) for name, param in model.named_parameters()]
+    named_params = _get_deepspeed_named_params(model, ds_config, is_lora=is_lora)

     waiting_params, waiting_params_size = [], 0
     for name, param in named_params:
@@ -150,7 +165,7 @@ def _setup_broadcast_group(self):
     def _colocated_model_update(self):
         refs = []
         for named_weights in gather_deepspeed_weights(
-            self.model, self.ds_config, buffer_size=self._model_update_buffer_size
+            self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
         ):
             serialized_tensors = serialize_named_weights(
                 named_weights, infer_strategy=self.infer_worker_config.strategy_args.strategy_name
@@ -167,11 +182,16 @@ def _colocated_model_update(self):
             ray.get(refs)
             refs = []
         if co_infer_rank == 0 and self._co_infer_worker is not None:
-            refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors))
+            refs.append(
+                self._co_infer_worker.update_parameter_in_bucket.remote(
+                    infer_parallel_tensors, is_lora=self.is_lora
+                )
+            )
         if self._broadcast_workers:
             refs.extend(self._broadcast_to_infer_workers(named_weights))
         if refs:
             ray.get(refs)
+        self._add_lora_to_infer_workers()
         return {}

     def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
@@ -183,6 +203,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
                 names=[n for n, _ in named_weights],
                 dtypes=[w.dtype for _, w in named_weights],
                 shapes=[w.shape for _, w in named_weights],
+                is_lora=self.is_lora,
             )
             for worker in self._broadcast_workers
         ]
@@ -198,8 +219,17 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
     def _separated_model_update(self):
         logger.info(f"start broadcast model update {self.model_update_group_name}")
         for named_weights in gather_deepspeed_weights(
-            self.model, self.ds_config, buffer_size=self._model_update_buffer_size
+            self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
         ):
             refs = self._broadcast_to_infer_workers(named_weights)
             ray.get(refs)
+        self._add_lora_to_infer_workers()
         return {}

+    def _add_lora_to_infer_workers(self):
+        if dist.get_rank() != 0 or not self.is_lora:
+            return
+        peft_config = self.model.peft_config.get("default", None)
+        ray.get(
+            [worker.add_lora.remote(peft_config=asdict(peft_config)) for worker in self.model_update_infer_workers]
+        )
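Note: the ZeRO-3 branch returns live `Parameter` objects (so the downstream gather can still collect the shards) while stripping PEFT's adapter qualifier from the names. A toy illustration of that renaming, with made-up module paths:

```python
adapter_name = "default"
state_keys = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
    "base_model.model.layers.0.self_attn.q_proj.base_layer.weight",
]

# Keep only LoRA tensors belonging to the active adapter...
lora_keys = [k for k in state_keys if "lora_" in k and adapter_name in k]
# ...and drop the ".default" qualifier so names match what the infer side expects.
normalized = [k.replace(f".{adapter_name}", "") for k in lora_keys]
assert normalized == [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight",
]
```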
2 changes: 1 addition & 1 deletion roll/utils/deepspeed_utils.py
@@ -36,7 +36,7 @@ def get_optimizer_grouped_parameters(
             "weight_decay": 0.0,
         },
     ]
-    return optimizer_grouped_parameters
+    return [group for group in optimizer_grouped_parameters if group["params"]]


 def _z3_params_to_fetch(param_list):
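Note: filtering out empty groups mostly matters for LoRA runs, where freezing base weights can leave the decay or no-decay bucket with zero parameters, and a parameterless group can trip up downstream optimizer/ZeRO setup. A small sketch of the effect:

```python
import torch

# With bias disabled, the usual no-decay bucket (biases, norms) ends up empty.
linear = torch.nn.Linear(4, 4, bias=False)
decay = [p for n, p in linear.named_parameters() if not n.endswith("bias")]
no_decay = [p for n, p in linear.named_parameters() if n.endswith("bias")]

groups = [
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},  # empty here
]
# Mirror the PR: hand only non-empty groups to the optimizer.
groups = [g for g in groups if g["params"]]
optimizer = torch.optim.AdamW(groups)
assert len(optimizer.param_groups) == 1
```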