diff --git a/roll/configs/data_args.py b/roll/configs/data_args.py
index 921ecd089..47ea816d4 100644
--- a/roll/configs/data_args.py
+++ b/roll/configs/data_args.py
@@ -20,6 +20,10 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to cache preprocessed datasets."},
+    )
     file_name: Optional[Union[List[str], str]] = field(
         default=None,
         metadata={"help": "The name of file path name for train. Conflicts with `--dataset_name`"},
diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py
index 39f2417fc..26e93f664 100644
--- a/roll/distributed/scheduler/generate_scheduler.py
+++ b/roll/distributed/scheduler/generate_scheduler.py
@@ -473,12 +473,12 @@ def __init__(
         self.dataset = dataset
         self.indices = list(range(len(dataset)))

-        if state is not None and state.get("dataset_iter_count", 0) > 0:
-            for _ in range(state["dataset_iter_count"]):
-                self.get_next_dataset_item()
         self.dataset_epoch = 0
         self.dataset_iter = None
         self.dataset_iter_count = 0
+        if state is not None and state.get("dataset_iter_count", 0) > 0:
+            for _ in range(state["dataset_iter_count"]):
+                self.get_next_dataset_item()

         self.collect_fn_cls = collect_fn_cls
         self.collect_fn_kwargs = collect_fn_kwargs
diff --git a/roll/distributed/scheduler/initialize.py b/roll/distributed/scheduler/initialize.py
index 877e4ef18..7d4194d06 100644
--- a/roll/distributed/scheduler/initialize.py
+++ b/roll/distributed/scheduler/initialize.py
@@ -1,4 +1,5 @@
 import os
+import shlex
 import subprocess
 import sys
 import time
@@ -36,12 +37,19 @@ def start_ray_cluster():
         logger.info("Ray cluster already initialized")
         return False

+    temp_dir = os.environ.get("ROLL_RAY_TEMP_DIR")
+    temp_dir_arg = ""
+    if temp_dir:
+        os.makedirs(temp_dir, exist_ok=True)
+        temp_dir_arg = f" --temp-dir={shlex.quote(temp_dir)}"
+
     if rank == 0:
         cmd = f"ray start --head --port={master_port} --node-name={node_name} --dashboard-port={dashboard_port}"
     else:
         # fix: handle the head/worker node creation-order mismatch that can occur at large scale
         time.sleep(5)
         cmd = f"ray start --address={master_addr}:{master_port} --node-name={node_name} --dashboard-port={dashboard_port}"
+    cmd += temp_dir_arg

     logger.info(f"Starting ray cluster: {cmd}")
     ret = subprocess.run(cmd, shell=True, capture_output=True)
diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
index dc7af456c..7f01770b9 100644
--- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
+++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
@@ -1,5 +1,6 @@
 import copy
 import json
+import math
 import os
 import uuid
 from functools import partial
@@ -224,22 +225,31 @@ def __init__(self, pipeline_config: RLVRConfig):
         dataset = get_vlm_dataset(
             self.pipeline_config.actor_train.data_args, encode_function, self.processor, get_eval=False
         )
-        # update domain field, DynamicSamplingScheduler requires
-        dataset = dataset.map(
-            partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
-            num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
-            desc="update_dataset_domain",
-            load_from_cache_file=False,
-        )
-
+        # Avoid rewriting large multimodal Arrow columns when a run has a single domain.
+        # Re-mapping PIL image columns can overflow pyarrow's regular array offsets.
+        domains = list(self.pipeline_config.actor_train.data_args.domain_interleave_probs.keys())
         self.domain_datasets: Dict[str, datasets.Dataset] = {}
-        for domain in self.pipeline_config.actor_train.data_args.domain_interleave_probs.keys():
-            self.domain_datasets[domain] = dataset.filter(
-                lambda example, dom: example["domain"] == dom,
+        if len(domains) == 1:
+            domain = domains[0]
+            if "domain" not in dataset.column_names:
+                dataset = dataset.add_column("domain", [domain] * len(dataset))
+            self.domain_datasets[domain] = dataset
+        else:
+            # update domain field, DynamicSamplingScheduler requires
+            dataset = dataset.map(
+                partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
                 num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
-                fn_kwargs={"dom": domain},
+                desc="update_dataset_domain",
+                load_from_cache_file=False,
             )
-            assert len(self.domain_datasets[domain]) > 0, f"domain dataset {domain} has no data"
+
+            for domain in domains:
+                self.domain_datasets[domain] = dataset.filter(
+                    lambda example, dom: example["domain"] == dom,
+                    num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
+                    fn_kwargs={"dom": domain},
+                )
+                assert len(self.domain_datasets[domain]) > 0, f"domain dataset {domain} has no data"

         self.val_dataset = None
         if self.pipeline_config.validation and self.pipeline_config.validation.data_args:
@@ -259,6 +269,17 @@ def __init__(self, pipeline_config: RLVRConfig):
             kl_horizon=self.pipeline_config.kl_horizon,
         )

+        if self.pipeline_config.max_steps <= 0:
+            num_train_epochs = self.pipeline_config.actor_train.training_args.num_train_epochs
+            dataset_size = sum(len(domain_dataset) for domain_dataset in self.domain_datasets.values())
+            inferred_max_steps = math.ceil(num_train_epochs * dataset_size / self.pipeline_config.rollout_batch_size)
+            logger.info(
+                "infer pipeline max_steps from dataset: "
+                f"num_train_epochs={num_train_epochs}, dataset_size={dataset_size}, "
+                f"rollout_batch_size={self.pipeline_config.rollout_batch_size}, "
+                f"max_steps={inferred_max_steps}"
+            )
+            self.pipeline_config.max_steps = inferred_max_steps
         assert self.pipeline_config.max_steps > 0, "max_steps must be greater than 0"
         self.pipeline_config.set_max_steps(max_steps=self.pipeline_config.max_steps)

diff --git a/roll/third_party/deepspeed/model_update.py b/roll/third_party/deepspeed/model_update.py
index b6452902c..61caf0c2e 100644
--- a/roll/third_party/deepspeed/model_update.py
+++ b/roll/third_party/deepspeed/model_update.py
@@ -1,3 +1,5 @@
+from dataclasses import asdict
+
 import ray
 import torch.distributed as dist
 from deepspeed.runtime.zero import GatheredParameters
@@ -29,9 +31,22 @@ def _gather_weights(is_zero3, named_params):
     return [(n, p.data) for n, p in named_params]


-def gather_deepspeed_weights(model, ds_config, buffer_size):
+def _get_deepspeed_named_params(model, ds_config, is_lora=False):
+    if not is_lora:
+        return [(name, param) for name, param in model.named_parameters()]
+
+    if not ds_config.is_zero3():
+        return [(name, param) for name, param in get_peft_model_state_dict(model).items()]
+
+    adapter_name = "default"
+    state_dict = model.state_dict()
+    lora_state_dict = {k: state_dict[k] for k in state_dict if ("lora_" in k and adapter_name in k)}
+    return [(name.replace(f".{adapter_name}", ""), model.get_parameter(name)) for name in lora_state_dict]
+
+
+def gather_deepspeed_weights(model, ds_config, buffer_size, is_lora=False):
     is_zero3 = ds_config.is_zero3()
-    named_params = [(name, param) for name, param in model.named_parameters()]
+    named_params = _get_deepspeed_named_params(model, ds_config, is_lora=is_lora)

     waiting_params, waiting_params_size = [], 0
     for name, param in named_params:
@@ -150,7 +165,7 @@ def _setup_broadcast_group(self):
     def _colocated_model_update(self):
         refs = []
         for named_weights in gather_deepspeed_weights(
-            self.model, self.ds_config, buffer_size=self._model_update_buffer_size
+            self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
         ):
             serialized_tensors = serialize_named_weights(
                 named_weights, infer_strategy=self.infer_worker_config.strategy_args.strategy_name
@@ -167,11 +182,16 @@
                     ray.get(refs)
                     refs = []
                 if co_infer_rank == 0 and self._co_infer_worker is not None:
-                    refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors))
+                    refs.append(
+                        self._co_infer_worker.update_parameter_in_bucket.remote(
+                            infer_parallel_tensors, is_lora=self.is_lora
+                        )
+                    )
                 if self._broadcast_workers:
                     refs.extend(self._broadcast_to_infer_workers(named_weights))
         if refs:
             ray.get(refs)
+        self._add_lora_to_infer_workers()
         return {}

     def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
@@ -183,6 +203,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
                 names=[n for n, _ in named_weights],
                 dtypes=[w.dtype for _, w in named_weights],
                 shapes=[w.shape for _, w in named_weights],
+                is_lora=self.is_lora,
             )
             for worker in self._broadcast_workers
         ]
@@ -198,8 +219,17 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
     def _separated_model_update(self):
         logger.info(f"start broadcast model update {self.model_update_group_name}")
         for named_weights in gather_deepspeed_weights(
-            self.model, self.ds_config, buffer_size=self._model_update_buffer_size
+            self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
         ):
             refs = self._broadcast_to_infer_workers(named_weights)
             ray.get(refs)
+        self._add_lora_to_infer_workers()
         return {}
+
+    def _add_lora_to_infer_workers(self):
+        if dist.get_rank() != 0 or not self.is_lora:
+            return
+        peft_config = self.model.peft_config.get("default", None)
+        ray.get(
+            [worker.add_lora.remote(peft_config=asdict(peft_config)) for worker in self.model_update_infer_workers]
+        )
diff --git a/roll/utils/deepspeed_utils.py b/roll/utils/deepspeed_utils.py
index fe7887748..2fd75d4ad 100644
--- a/roll/utils/deepspeed_utils.py
+++ b/roll/utils/deepspeed_utils.py
@@ -36,7 +36,7 @@ def get_optimizer_grouped_parameters(
             "weight_decay": 0.0,
         },
     ]
-    return optimizer_grouped_parameters
+    return [group for group in optimizer_grouped_parameters if group["params"]]


 def _z3_params_to_fetch(param_list):
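
A note on the generate_scheduler.py reorder: restoring a checkpointed position calls get_next_dataset_item, which presumably advances dataset_iter_count and the iterator state, so those attributes must exist and be zeroed before the replay loop runs; in the old order the replay ran first and its progress was then wiped by the re-initialization. A stripped-down sketch (not the real scheduler class) of the intended ordering:

```python
# Minimal sketch, assuming get_next_dataset_item advances the counters the
# checkpoint state refers to; names mirror the diff but the class is invented.
class TinySampler:
    def __init__(self, dataset, state=None):
        self.dataset = dataset
        self.dataset_epoch = 0
        self.dataset_iter = None
        self.dataset_iter_count = 0
        # Replay the saved position only after the counters above are initialized.
        if state is not None and state.get("dataset_iter_count", 0) > 0:
            for _ in range(state["dataset_iter_count"]):
                self.get_next_dataset_item()

    def get_next_dataset_item(self):
        if self.dataset_iter is None:
            self.dataset_iter = iter(self.dataset)
        try:
            item = next(self.dataset_iter)
        except StopIteration:
            self.dataset_epoch += 1
            self.dataset_iter = iter(self.dataset)
            item = next(self.dataset_iter)
        self.dataset_iter_count += 1
        return item

sampler = TinySampler(list(range(10)), state={"dataset_iter_count": 3})
assert sampler.dataset_iter_count == 3  # restored position survives __init__
```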
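
The ROLL_RAY_TEMP_DIR handling in initialize.py builds the extra argument once and appends it to whichever `ray start` command was selected; shlex.quote keeps a path with spaces as a single shell token. A small sketch of the resulting command string, with hypothetical port and node names:

```python
import os
import shlex

# Illustration only: port/node values are made up, the real ones come from
# master_port, node_name and dashboard_port in start_ray_cluster().
os.environ["ROLL_RAY_TEMP_DIR"] = "/mnt/scratch/ray tmp"

temp_dir = os.environ.get("ROLL_RAY_TEMP_DIR")
temp_dir_arg = f" --temp-dir={shlex.quote(temp_dir)}" if temp_dir else ""

cmd = "ray start --head --port=6379 --node-name=head --dashboard-port=8265" + temp_dir_arg
print(cmd)
# ray start --head --port=6379 --node-name=head --dashboard-port=8265 --temp-dir='/mnt/scratch/ray tmp'
```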
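
For the single-domain shortcut in rlvr_vlm_pipeline.py, the dataset is used as-is and only a constant `domain` column is appended when missing, so the image columns are never pushed through a map call. A toy illustration with a hypothetical domain tag:

```python
from datasets import Dataset

# Toy dataset; "math_rule" is a hypothetical domain tag, not one from the repo.
ds = Dataset.from_dict({"prompt": ["a", "b"]})
domain = "math_rule"
if "domain" not in ds.column_names:
    ds = ds.add_column("domain", [domain] * len(ds))
print(ds.column_names)  # ['prompt', 'domain']
```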
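
The inferred max_steps follows max_steps = ceil(num_train_epochs * dataset_size / rollout_batch_size), where dataset_size is the sum of the per-domain dataset lengths. A worked example with made-up numbers:

```python
import math

# Hypothetical values, only to illustrate the formula used in the pipeline.
num_train_epochs = 2
dataset_size = 12_000        # sum of len(domain_dataset) over all domains
rollout_batch_size = 512

inferred_max_steps = math.ceil(num_train_epochs * dataset_size / rollout_batch_size)
print(inferred_max_steps)    # 47
```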
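
In _get_deepspeed_named_params, the ZeRO-3 LoRA branch keeps live parameter objects (presumably so GatheredParameters can still materialize them) and strips the adapter segment from each key. The snippet below only illustrates that name rewrite on a made-up PEFT parameter name:

```python
# Illustration of the name normalization only; in the real code the parameter
# objects come from model.get_parameter(name). The example name is invented.
adapter_name = "default"
name = "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
print(name.replace(f".{adapter_name}", ""))
# base_model.model.layers.0.self_attn.q_proj.lora_A.weight
```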
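
The get_optimizer_grouped_parameters change drops parameter groups that end up empty, which can happen when, for example, LoRA-only training leaves nothing in the weight-decay group; that motivation is an assumption, but the filter itself is just:

```python
# Toy illustration: only non-empty groups are kept for the optimizer.
optimizer_grouped_parameters = [
    {"params": [], "weight_decay": 0.01},            # e.g. no decayed params under LoRA
    {"params": ["lora_A", "lora_B"], "weight_decay": 0.0},
]
print([g for g in optimizer_grouped_parameters if g["params"]])
# [{'params': ['lora_A', 'lora_B'], 'weight_decay': 0.0}]
```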