diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index d5186e868a..05d46568a1 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -19,6 +19,7 @@
 import copy
 import json
 import math
+import os
 import statistics
 from collections import defaultdict
 from dataclasses import dataclass
@@ -221,9 +222,13 @@ async def generate_responses_async(
     }
     # Attach worker metadata if present (async vLLM path)
     if "gen_leader_worker_idx" in generation_outputs:
-        # generation_outputs carries this as a 1-length list per row; convert to int
         v = generation_outputs["gen_leader_worker_idx"][0]
         try:
+            per_sample_worker_indices = [
+                int(worker_idx[0]) if isinstance(worker_idx, list) else int(worker_idx)
+                for worker_idx in generation_outputs["gen_leader_worker_idx"]
+            ]
+            gen_metrics["gen_leader_worker_idx_per_sample"] = per_sample_worker_indices
             gen_metrics["gen_leader_worker_idx"] = (
                 int(v[0]) if isinstance(v, list) else int(v)
             )
@@ -663,10 +668,166 @@ async def async_generate_response_for_sample_turn(
     return updated_message_log, generated_tokens, input_lengths, gen_metrics
 
 
+@dataclass
+class _AsyncGenerationRequest:
+    sample_message_log: list[dict]
+    sample_stop_strings: list[str] | None
+    response_future: asyncio.Future
+
+
+class _AsyncGenerationBroker:
+    """Batch and dispatch async generation requests from per-sample rollout coroutines."""
+
+    def __init__(
+        self,
+        policy_generation: GenerationInterface,
+        tokenizer: TokenizerType,
+        max_seq_len: int,
+        greedy: bool = False,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
+        self.policy_generation = policy_generation
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        self.greedy = greedy
+        self.max_batch_size = max_batch_size or int(
+            os.environ.get("NRL_ASYNC_MULTI_TURN_MAX_BATCH", "0") or 0
+        )
+        self._request_queue: asyncio.Queue[_AsyncGenerationRequest | None] = (
+            asyncio.Queue()
+        )
+        self._closed = False
+        self._driver_task = asyncio.create_task(self._run())
+
+    async def generate(
+        self,
+        sample_message_log: list[dict],
+        sample_stop_strings: list[str] | None,
+    ) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, Any]]:
+        loop = asyncio.get_running_loop()
+        response_future: asyncio.Future = loop.create_future()
+        await self._request_queue.put(
+            _AsyncGenerationRequest(
+                sample_message_log=sample_message_log,
+                sample_stop_strings=sample_stop_strings,
+                response_future=response_future,
+            )
+        )
+        return await response_future
+
+    async def close(self) -> None:
+        if self._closed:
+            return
+        self._closed = True
+        await self._request_queue.put(None)
+        await self._driver_task
+
+    async def _run(self) -> None:
+        batch_requests: list[_AsyncGenerationRequest] = []
+        try:
+            while True:
+                first_request = await self._request_queue.get()
+                if first_request is None:
+                    break
+
+                batch_requests = [first_request]
+                await asyncio.sleep(0)
+                while True:
+                    if self.max_batch_size and len(batch_requests) >= self.max_batch_size:
+                        break
+                    try:
+                        maybe_request = self._request_queue.get_nowait()
+                    except asyncio.QueueEmpty:
+                        break
+                    if maybe_request is None:
+                        self._closed = True
+                        break
+                    batch_requests.append(maybe_request)
+
+                batch_message_logs = [
+                    request.sample_message_log for request in batch_requests
+                ]
+                batch_stop_strings = [
+                    request.sample_stop_strings for request in batch_requests
+                ]
+
+                flat_messages, input_lengths = batched_message_log_to_flat_message(
+                    batch_message_logs,
+                    pad_value_dict={"token_ids": self.tokenizer.pad_token_id},
+                )
+
+                generation_input_data = BatchedDataDict[GenerationDatumSpec](
+                    {
+                        "input_ids": flat_messages["token_ids"],
+                        "input_lengths": input_lengths,
+                        "stop_strings": batch_stop_strings,
+                    }
+                )
+                dummy_batch = BatchedDataDict[DatumSpec](
+                    {
+                        "message_log": batch_message_logs,
+                        "stop_strings": batch_stop_strings,
+                    }
+                )
+
+                updated_batch, generated_ids, gen_metrics = await generate_responses_async(
+                    self.policy_generation,
+                    generation_input_data,
+                    dummy_batch,
+                    self.tokenizer,
+                    input_lengths=input_lengths,
+                    include_logprobs=True,
+                    greedy=self.greedy,
+                )
+
+                per_sample_worker_idxs = gen_metrics.get(
+                    "gen_leader_worker_idx_per_sample", []
+                )
+                response_truncated = gen_metrics.get("_response_truncated")
+
+                for i, request in enumerate(batch_requests):
+                    per_sample_metrics: dict[str, Any] = {}
+                    if i < len(per_sample_worker_idxs):
+                        per_sample_metrics["gen_leader_worker_idx"] = int(
+                            per_sample_worker_idxs[i]
+                        )
+                    if response_truncated is not None:
+                        per_sample_metrics["_response_truncated"] = response_truncated[
+                            i : i + 1
+                        ]
+
+                    request.response_future.set_result(
+                        (
+                            updated_batch["message_log"][i],
+                            generated_ids[i],
+                            input_lengths[i],
+                            per_sample_metrics,
+                        )
+                    )
+
+                if self._closed and self._request_queue.empty():
+                    break
+        except Exception as e:
+            for request in batch_requests:
+                if not request.response_future.done():
+                    request.response_future.set_exception(e)
+            while True:
+                try:
+                    maybe_request = self._request_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    break
+                if maybe_request is None:
+                    continue
+                if not maybe_request.response_future.done():
+                    maybe_request.response_future.set_exception(e)
+            raise
+
+
 async def run_sample_multi_turn_rollout(
     sample_idx: int,
     initial_sample_state: dict,
     policy_generation: GenerationInterface,
+    generation_broker: _AsyncGenerationBroker,
     tokenizer: TokenizerType,
     task_to_env: dict[str, EnvironmentInterface],
     max_seq_len: int,
@@ -731,13 +892,9 @@ async def run_sample_multi_turn_rollout(
             generated_tokens,
             input_lengths,
             gen_metrics,
-        ) = await async_generate_response_for_sample_turn(
-            policy_generation,
+        ) = await generation_broker.generate(
             current_message_log,
             current_stop_strings,
-            tokenizer,
-            max_seq_len,
-            greedy=greedy,
         )
 
         current_message_log = updated_message_log
@@ -774,7 +931,7 @@ async def run_sample_multi_turn_rollout(
         )
 
         # Get environment feedback
-        env_output = calculate_rewards(sample_batch, task_to_env)
+        env_output = await asyncio.to_thread(calculate_rewards, sample_batch, task_to_env)
         # Update total reward and optional per-reward signals (reward1, reward2, ... rewardN)
         if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1:
             multi_reward_seen = True
@@ -891,6 +1048,12 @@ def run_async_multi_turn_rollout(
     async def _async_rollout_implementation():
         """Internal async implementation."""
         batch_size = len(input_batch["message_log"])
+        generation_broker = _AsyncGenerationBroker(
+            policy_generation=policy_generation,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len,
+            greedy=greedy,
+        )
 
         # Prepare initial states for each sample
         sample_initial_states = []
@@ -912,6 +1075,7 @@ async def run_single_sample_with_error_handling(i, sample_state):
                 sample_idx=i,
                 initial_sample_state=sample_state,
                 policy_generation=policy_generation,
+                generation_broker=generation_broker,
                 tokenizer=tokenizer,
                 task_to_env=task_to_env,
                 max_seq_len=max_seq_len,
@@ -929,7 +1093,10 @@ async def run_single_sample_with_error_handling(i, sample_state):
         ]
 
         # Execute all sample rollouts concurrently
-        sample_results = await asyncio.gather(*sample_tasks, return_exceptions=False)
+        try:
+            sample_results = await asyncio.gather(*sample_tasks, return_exceptions=False)
+        finally:
+            await generation_broker.close()
 
         # Process results
         final_sample_states = []
diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py
index 0faaad17a1..977066677a 100644
--- a/nemo_rl/models/generation/vllm/vllm_generation.py
+++ b/nemo_rl/models/generation/vllm/vllm_generation.py
@@ -587,37 +587,22 @@ async def _async_generate_base(
         if not data_validation_fn(data):
             return
 
-        # Determine the leader worker for the current data parallel shard
-        leader_worker_idx = self.worker_group.get_dp_leader_worker_idx(
-            self.current_generate_dp_shard_idx
-        )
-
-        # Run the async method on the selected leader worker
-        worker_gen_proxy = self.worker_group.run_single_worker_single_data(
-            method_name=method_name,
-            worker_idx=leader_worker_idx,
-            data=data,
-            greedy=greedy,
-        )
+        total_batch_size = len(next(iter(data.values())))
 
-        # Increment the round-robin worker group index
-        self.current_generate_dp_shard_idx += 1
-        self.current_generate_dp_shard_idx %= self.worker_group.dp_size
-
-        # Create a queue to collect sample results from the worker as they complete
+        # Create a queue to collect sample results from workers as they complete.
         result_queue = asyncio.Queue()
-        finished = False
+        finished_workers = 0
 
-        async def consume_worker_generator(worker_idx, worker_gen):
-            """Consume a single worker generator and put sample results in the queue."""
-            nonlocal finished
+        async def consume_worker_generator(worker_idx, worker_gen, original_idx_offset):
+            """Consume a worker generator and put sample results in the queue."""
            worker_name = f"Worker-{worker_idx}"
             try:
                 async for sample_result_ref in worker_gen:
                     sample_result = await sample_result_ref
                     # sample_result is a tuple: (original_idx, BatchedDataDict)
                     # Tag the result with worker index for downstream attribution
-                    original_idx, result_batch = sample_result
+                    local_original_idx, result_batch = sample_result
+                    original_idx = original_idx_offset + int(local_original_idx)
                     # Use a length-one list so BatchedDataDict.from_batches can merge without shape errors
                     result_batch["gen_leader_worker_idx"] = [int(worker_idx)]
                     sample_result = (original_idx, result_batch)
@@ -630,20 +615,64 @@ async def consume_worker_generator(worker_idx, worker_gen):
                     traceback.print_exc()
                     await result_queue.put(("error", e))
             finally:
-                finished = True
-                await result_queue.put(("worker_done", None))
+                await result_queue.put(("worker_done", worker_idx))
 
-        # Start the task to consume the worker generator
-        worker_task = asyncio.create_task(
-            consume_worker_generator(leader_worker_idx, worker_gen_proxy)
-        )
+        worker_tasks: list[asyncio.Task] = []
+
+        if total_batch_size == 1:
+            # For single-sample requests, keep round-robin dispatch across DP leaders so
+            # concurrent callers naturally spread work across all inference engines.
+            leader_worker_idx = self.worker_group.get_dp_leader_worker_idx(
+                self.current_generate_dp_shard_idx
+            )
+            worker_gen_proxy = self.worker_group.run_single_worker_single_data(
+                method_name=method_name,
+                worker_idx=leader_worker_idx,
+                data=data,
+                greedy=greedy,
+            )
+            self.current_generate_dp_shard_idx += 1
+            self.current_generate_dp_shard_idx %= self.worker_group.dp_size
+            worker_tasks.append(
+                asyncio.create_task(
+                    consume_worker_generator(leader_worker_idx, worker_gen_proxy, 0)
+                )
+            )
+        else:
+            # Shard the batch across all DP leaders, matching the sync generate path.
+            dp_size = self.sharding_annotations.get_axis_size("data_parallel")
+            sharded_data: list[SlicedDataDict] = data.shard_by_batch_size(
+                dp_size, allow_uneven_shards=True
+            )
+            shard_offset = 0
+            for dp_shard_idx, shard in enumerate(sharded_data):
+                shard_batch_size = len(next(iter(shard.values()))) if len(shard) > 0 else 0
+                if shard_batch_size == 0:
+                    continue
+                leader_worker_idx = self.worker_group.get_dp_leader_worker_idx(
+                    dp_shard_idx
+                )
+                worker_gen_proxy = self.worker_group.run_single_worker_single_data(
+                    method_name=method_name,
+                    worker_idx=leader_worker_idx,
+                    data=shard,
+                    greedy=greedy,
+                )
+                worker_tasks.append(
+                    asyncio.create_task(
+                        consume_worker_generator(
+                            leader_worker_idx, worker_gen_proxy, shard_offset
+                        )
+                    )
+                )
+                shard_offset += shard_batch_size
 
         # Yield sample results as they become available from the worker
         timeout_seconds = float(
             os.environ.get("NRL_VLLM_ASYNC_TIMEOUT_SECONDS", "600")
         )  # Default 10 minutes
 
-        while not finished:
+        while finished_workers < len(worker_tasks):
             try:
                 msg_type, item = await asyncio.wait_for(
                     result_queue.get(), timeout=timeout_seconds
@@ -655,10 +684,10 @@ async def consume_worker_generator(worker_idx, worker_gen):
                 print(
                     f"For longer sequences, increase the timeout by setting: export NRL_VLLM_ASYNC_TIMEOUT_SECONDS={int(timeout_seconds * 2)}"
                 )
-                # Cancel the task
-                if not worker_task.done():
-                    worker_task.cancel()
-                await asyncio.gather(worker_task, return_exceptions=True)
+                for worker_task in worker_tasks:
+                    if not worker_task.done():
+                        worker_task.cancel()
+                await asyncio.gather(*worker_tasks, return_exceptions=True)
                 raise RuntimeError(
                     f"Timeout waiting for worker results after {timeout_seconds}s. "
                     f"For longer sequences, increase timeout by setting: export NRL_VLLM_ASYNC_TIMEOUT_SECONDS={int(timeout_seconds * 2)}"
@@ -669,20 +698,18 @@ async def consume_worker_generator(worker_idx, worker_gen):
                 yield item
             elif msg_type == "error":
                 # Cancel the task and propagate error
-                if not worker_task.done():
-                    worker_task.cancel()
-                await asyncio.gather(worker_task, return_exceptions=True)
+                for worker_task in worker_tasks:
+                    if not worker_task.done():
+                        worker_task.cancel()
+                await asyncio.gather(*worker_tasks, return_exceptions=True)
                 raise item
             elif msg_type == "worker_done":
-                # Worker finished, just continue the loop
-                pass
+                finished_workers += 1
             else:
                 raise RuntimeError(f"Unexpected message type: {msg_type}")
 
-        # Verify the task is actually done
-        assert worker_task.done(), (
-            f"Worker task {leader_worker_idx} should be done but isn't"
-        )
+        for worker_task in worker_tasks:
+            assert worker_task.done(), "A worker task should be done but isn't"
 
     async def generate_text_async(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py
index 91d4d9a292..e62fa43e92 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker_async.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -711,12 +711,6 @@ async def generate_async(
         input_lengths_batch = data["input_lengths"]
         batch_size = input_ids_batch.shape[0]
 
-        # Ensure generate_async only receives single samples (batch_size = 1)
-        assert batch_size == 1, (
-            f"generate_async is restricted to handle only single samples, "
-            f"but received batch_size={batch_size}. Please handle batching outside this method."
-        )
-
         batch_specific_stop_strings_list = data.get(
             "stop_strings", [[] for _ in range(batch_size)]
        )
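For reviewers unfamiliar with the pattern, the sketch below is an illustrative, self-contained analogue of what _AsyncGenerationBroker does in the patch: per-sample coroutines funnel requests into one queue, a single driver task drains whatever has accumulated, issues one batched call, and resolves each caller's future. MiniBroker and fake_generate_batch are hypothetical names used only for this sketch; they are not part of nemo_rl.

import asyncio


async def fake_generate_batch(prompts):
    """Hypothetical stand-in for one batched backend call (e.g. generate_responses_async)."""
    await asyncio.sleep(0.01)
    return [p + " -> reply" for p in prompts]


class MiniBroker:
    """Toy analogue of _AsyncGenerationBroker: many awaiting callers, one batched call."""

    def __init__(self):
        self._queue = asyncio.Queue()
        self._closing = False
        self._driver = asyncio.create_task(self._run())

    async def generate(self, prompt):
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt, fut))
        return await fut

    async def close(self):
        await self._queue.put(None)  # sentinel: no more requests
        await self._driver

    async def _run(self):
        while True:
            first = await self._queue.get()
            if first is None:
                return
            batch = [first]
            await asyncio.sleep(0)  # let concurrently waiting callers enqueue first
            while not self._queue.empty():
                item = self._queue.get_nowait()
                if item is None:
                    self._closing = True
                    break
                batch.append(item)
            replies = await fake_generate_batch([prompt for prompt, _ in batch])
            for (_, fut), reply in zip(batch, replies):
                fut.set_result(reply)
            if self._closing:
                return


async def main():
    broker = MiniBroker()
    results = await asyncio.gather(*(broker.generate(f"q{i}") for i in range(4)))
    await broker.close()
    print(results)  # four replies, typically served by one or two batched calls


asyncio.run(main())

The await asyncio.sleep(0) before draining mirrors the broker's yield point in the patch: it gives already-scheduled callers a chance to enqueue before the batch is sealed, which is what lets many single-sample rollout coroutines share one batched generation call.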
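A second point worth making explicit is the index bookkeeping in the multi-shard path of _async_generate_base: each worker reports sample indices local to its shard, and original_idx_offset maps them back to positions in the caller's batch. The sketch below shows that arithmetic in isolation, assuming an uneven split where earlier shards receive the remainder; split_with_offsets is a hypothetical helper for illustration, not the actual shard_by_batch_size implementation.

def split_with_offsets(batch_size: int, dp_size: int) -> tuple[list[int], list[int]]:
    """Return per-shard sizes and the running offsets used to globalize local indices."""
    base, rem = divmod(batch_size, dp_size)
    sizes = [base + (1 if i < rem else 0) for i in range(dp_size)]
    offsets, start = [], 0
    for size in sizes:
        offsets.append(start)
        start += size
    return sizes, offsets


sizes, offsets = split_with_offsets(batch_size=10, dp_size=4)
assert sizes == [3, 3, 2, 2] and offsets == [0, 3, 6, 8]
# A result with local index 1 coming from shard 2 maps to global index offsets[2] + 1 == 7,
# which is the same globalization consume_worker_generator performs with original_idx_offset.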