183 changes: 175 additions & 8 deletions nemo_rl/experience/rollouts.py
@@ -19,6 +19,7 @@
import copy
import json
import math
import os
import statistics
from collections import defaultdict
from dataclasses import dataclass
@@ -221,9 +222,13 @@ async def generate_responses_async(
}
# Attach worker metadata if present (async vLLM path)
if "gen_leader_worker_idx" in generation_outputs:
# generation_outputs carries this as a 1-length list per row; convert to int
v = generation_outputs["gen_leader_worker_idx"][0]
try:
per_sample_worker_indices = [
int(worker_idx[0]) if isinstance(worker_idx, list) else int(worker_idx)
for worker_idx in generation_outputs["gen_leader_worker_idx"]
]
gen_metrics["gen_leader_worker_idx_per_sample"] = per_sample_worker_indices
gen_metrics["gen_leader_worker_idx"] = (
int(v[0]) if isinstance(v, list) else int(v)
)
@@ -663,10 +668,166 @@ async def async_generate_response_for_sample_turn(
return updated_message_log, generated_tokens, input_lengths, gen_metrics


@dataclass
class _AsyncGenerationRequest:
sample_message_log: list[dict]
sample_stop_strings: list[str] | None
response_future: asyncio.Future


class _AsyncGenerationBroker:
"""Batch and dispatch async generation requests from per-sample rollout coroutines."""

def __init__(
self,
policy_generation: GenerationInterface,
tokenizer: TokenizerType,
max_seq_len: int,
greedy: bool = False,
max_batch_size: Optional[int] = None,
) -> None:
self.policy_generation = policy_generation
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.greedy = greedy
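        # A cap of 0 (the default when neither the argument nor the
        # NRL_ASYNC_MULTI_TURN_MAX_BATCH env var is set) means "no limit".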
self.max_batch_size = max_batch_size or int(
os.environ.get("NRL_ASYNC_MULTI_TURN_MAX_BATCH", "0") or 0
)
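        # A None item on this queue is the shutdown sentinel for the driver loop.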
self._request_queue: asyncio.Queue[_AsyncGenerationRequest | None] = (
asyncio.Queue()
)
self._closed = False
self._driver_task = asyncio.create_task(self._run())

async def generate(
self,
sample_message_log: list[dict],
sample_stop_strings: list[str] | None,
) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, Any]]:
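        # Park the request on the queue and suspend; the driver loop fulfils
        # the future with this sample's slice of the batched output.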
loop = asyncio.get_running_loop()
response_future: asyncio.Future = loop.create_future()
await self._request_queue.put(
_AsyncGenerationRequest(
sample_message_log=sample_message_log,
sample_stop_strings=sample_stop_strings,
response_future=response_future,
)
)
return await response_future

async def close(self) -> None:
if self._closed:
return
self._closed = True
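        # Wake the driver with the shutdown sentinel, then wait for it to drain.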
await self._request_queue.put(None)
await self._driver_task

async def _run(self) -> None:
batch_requests: list[_AsyncGenerationRequest] = []
try:
while True:
first_request = await self._request_queue.get()
if first_request is None:
break

batch_requests = [first_request]
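                # Yield once so sibling rollout coroutines scheduled in this
                # event-loop tick can enqueue their requests and share the batch.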
await asyncio.sleep(0)
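                # Greedily drain everything already queued, up to the optional cap.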
while True:
if self.max_batch_size and len(batch_requests) >= self.max_batch_size:
break
try:
maybe_request = self._request_queue.get_nowait()
except asyncio.QueueEmpty:
break
if maybe_request is None:
self._closed = True
break
batch_requests.append(maybe_request)

batch_message_logs = [
request.sample_message_log for request in batch_requests
]
batch_stop_strings = [
request.sample_stop_strings for request in batch_requests
]

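                # Flatten and pad the collected message logs into one generation batch.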
flat_messages, input_lengths = batched_message_log_to_flat_message(
batch_message_logs,
pad_value_dict={"token_ids": self.tokenizer.pad_token_id},
)

generation_input_data = BatchedDataDict[GenerationDatumSpec](
{
"input_ids": flat_messages["token_ids"],
"input_lengths": input_lengths,
"stop_strings": batch_stop_strings,
}
)
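                # Stand-in DatumSpec batch; only message_log and stop_strings
                # are populated for generate_responses_async here.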
dummy_batch = BatchedDataDict[DatumSpec](
{
"message_log": batch_message_logs,
"stop_strings": batch_stop_strings,
}
)

updated_batch, generated_ids, gen_metrics = await generate_responses_async(
self.policy_generation,
generation_input_data,
dummy_batch,
self.tokenizer,
input_lengths=input_lengths,
include_logprobs=True,
greedy=self.greedy,
)

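                # Scatter the batched results and per-sample metrics back to
                # the future of each originating request.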
per_sample_worker_idxs = gen_metrics.get(
"gen_leader_worker_idx_per_sample", []
)
response_truncated = gen_metrics.get("_response_truncated")

for i, request in enumerate(batch_requests):
per_sample_metrics: dict[str, Any] = {}
if i < len(per_sample_worker_idxs):
per_sample_metrics["gen_leader_worker_idx"] = int(
per_sample_worker_idxs[i]
)
if response_truncated is not None:
per_sample_metrics["_response_truncated"] = response_truncated[
i : i + 1
]

request.response_future.set_result(
(
updated_batch["message_log"][i],
generated_ids[i],
input_lengths[i],
per_sample_metrics,
)
)

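                # close() was observed while draining; exit once nothing is queued.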
if self._closed and self._request_queue.empty():
break
except Exception as e:
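            # Fail the in-flight batch and everything still queued so that no
            # awaiting rollout coroutine hangs forever.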
for request in batch_requests:
if not request.response_future.done():
request.response_future.set_exception(e)
while True:
try:
maybe_request = self._request_queue.get_nowait()
except asyncio.QueueEmpty:
break
if maybe_request is None:
continue
if not maybe_request.response_future.done():
maybe_request.response_future.set_exception(e)
raise


async def run_sample_multi_turn_rollout(
sample_idx: int,
initial_sample_state: dict,
policy_generation: GenerationInterface,
generation_broker: _AsyncGenerationBroker,
tokenizer: TokenizerType,
task_to_env: dict[str, EnvironmentInterface],
max_seq_len: int,
@@ -731,13 +892,9 @@ async def run_sample_multi_turn_rollout(
generated_tokens,
input_lengths,
gen_metrics,
) = await async_generate_response_for_sample_turn(
policy_generation,
) = await generation_broker.generate(
current_message_log,
current_stop_strings,
tokenizer,
max_seq_len,
greedy=greedy,
)
current_message_log = updated_message_log

@@ -774,7 +931,7 @@
)

# Get environment feedback
env_output = calculate_rewards(sample_batch, task_to_env)
env_output = await asyncio.to_thread(calculate_rewards, sample_batch, task_to_env)
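        # asyncio.to_thread runs the synchronous reward computation off the event
        # loop so other per-sample rollouts keep making progress meanwhile.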
# Update total reward and optional per-reward signals (reward1, reward2, ... rewardN)
if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1:
multi_reward_seen = True
@@ -891,6 +1048,12 @@ def run_async_multi_turn_rollout(
async def _async_rollout_implementation():
"""Internal async implementation."""
batch_size = len(input_batch["message_log"])
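        # One broker is shared by all per-sample rollouts in this batch so their
        # generation requests can be coalesced into larger engine calls.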
generation_broker = _AsyncGenerationBroker(
policy_generation=policy_generation,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
greedy=greedy,
)

# Prepare initial states for each sample
sample_initial_states = []
@@ -912,6 +1075,7 @@ async def run_single_sample_with_error_handling(i, sample_state):
sample_idx=i,
initial_sample_state=sample_state,
policy_generation=policy_generation,
generation_broker=generation_broker,
tokenizer=tokenizer,
task_to_env=task_to_env,
max_seq_len=max_seq_len,
@@ -929,7 +1093,10 @@ async def run_single_sample_with_error_handling(i, sample_state):
]

# Execute all sample rollouts concurrently
sample_results = await asyncio.gather(*sample_tasks, return_exceptions=False)
try:
sample_results = await asyncio.gather(*sample_tasks, return_exceptions=False)
finally:
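            # Always tear the broker down so its driver task does not leak,
            # even when a rollout raises.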
await generation_broker.close()

# Process results
final_sample_states = []