diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py index 6a7288ff52..2927647850 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -32,6 +32,26 @@ ) +def _get_sample_valid_mask( + sample_valid_mask: torch.Tensor | None, + reference: torch.Tensor, +) -> torch.Tensor: + """Normalize optional sample validity info to a boolean tensor.""" + if sample_valid_mask is None: + return torch.ones_like(reference, dtype=torch.bool) + normalized_mask = sample_valid_mask.to( + device=reference.device, + dtype=torch.bool, + ).reshape(-1) + if normalized_mask.shape[0] != reference.shape[0]: + raise ValueError( + "sample_valid_mask must have one element per sample; " + f"got {normalized_mask.shape[0]} values for " + f"{reference.shape[0]} samples" + ) + return normalized_mask + + class GRPOAdvantageEstimator: """GRPO-style advantage estimator with leave-one-out baseline. @@ -42,7 +62,9 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] self.normalize_rewards = estimator_config["normalize_rewards"] - def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): + def compute_advantage( + self, prompt_ids, rewards, mask, sample_valid_mask=None, **kwargs + ): """Compute GRPO advantages. Args: @@ -55,13 +77,15 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): Returns: Advantages tensor of shape [batch_size, seq_len]. """ + sample_valid_mask = _get_sample_valid_mask(sample_valid_mask, rewards) baseline, std = calculate_baseline_and_std_per_prompt( prompt_ids, rewards, - torch.ones_like(rewards), + sample_valid_mask, leave_one_out_baseline=self.use_leave_one_out_baseline, ) advantages = (rewards - baseline).unsqueeze(-1) + advantages = advantages * sample_valid_mask.unsqueeze(-1).float() if self.normalize_rewards: # don't sharpen the ones with no variation @@ -90,6 +114,7 @@ def compute_advantage( rewards, mask, repeated_batch, + sample_valid_mask=None, **kwargs, ): """Compute GDPO advantages. @@ -111,7 +136,10 @@ def compute_advantage( f"This batch has {len(reward_component_keys)} component(s). " "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config." ) - valid = torch.ones_like(repeated_batch[reward_component_keys[0]]) + valid = _get_sample_valid_mask( + sample_valid_mask, + repeated_batch[reward_component_keys[0]], + ) leave_one_out = self.use_leave_one_out_baseline assert prompt_ids.shape[0] == valid.shape[0], ( "prompt_ids must match reward batch size; " @@ -137,12 +165,17 @@ def compute_advantage( advantage_parts.append(adv_k) advantages = sum(advantage_parts) - # Normalize combined advantage to zero mean and unit std - adv_std = advantages.std() - if adv_std > 0: - advantages = (advantages - advantages.mean()) / adv_std + valid_advantages = advantages[valid] + if valid_advantages.numel() <= 1: + advantages = torch.zeros_like(advantages) else: - advantages = advantages - advantages.mean() + # Normalize combined advantage to zero mean and unit std using only valid samples. 
+ adv_mean = valid_advantages.mean() + adv_std = valid_advantages.std() + advantages = advantages - adv_mean + if adv_std > 0: + advantages[valid] = advantages[valid] / adv_std + advantages[~valid] = 0.0 return advantages.expand(mask.shape) @@ -166,6 +199,7 @@ def compute_advantage( prompt_ids, rewards, mask, + sample_valid_mask=None, logprobs_policy=None, logprobs_reference=None, **kwargs, @@ -185,18 +219,25 @@ def compute_advantage( Returns: Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. """ + sample_valid_mask = _get_sample_valid_mask(sample_valid_mask, rewards) + sample_valid_token_mask = sample_valid_mask.unsqueeze(-1).to( + device=mask.device, + dtype=mask.dtype, + ) + effective_mask = mask * sample_valid_token_mask # minus baseline if self.minus_baseline: mean, _ = calculate_baseline_and_std_per_prompt( prompt_ids, rewards, - torch.ones_like(rewards), + sample_valid_mask, leave_one_out_baseline=False, ) adv = rewards - mean else: adv = rewards + adv = adv * sample_valid_mask.float() adv = adv.unsqueeze(-1) adv = adv.expand(mask.shape) @@ -212,11 +253,17 @@ def compute_advantage( kl_type=self.kl_type, ) adv = adv - self.kl_coef * kl + adv = adv * sample_valid_token_mask # global normalization across the batch - adv_mean = (adv * mask).sum() / mask.sum() - adv_var = ((adv - adv_mean).pow(2) * mask).sum() / mask.sum() + if effective_mask.sum() == 0: + return torch.zeros_like(adv) + adv_mean = (adv * effective_mask).sum() / effective_mask.sum() + adv_var = ( + ((adv - adv_mean).pow(2) * effective_mask).sum() / effective_mask.sum() + ) adv_rstd = adv_var.clamp(min=1e-8).rsqrt() adv = (adv - adv_mean) * adv_rstd + adv = adv * sample_valid_token_mask return adv diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 02e43ae659..9d71305a82 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -975,6 +975,95 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: return repeated_batch +def _get_batch_loss_multiplier( + repeated_batch: BatchedDataDict[DatumSpec], + batch_size: int, + device: torch.device | None = None, +) -> torch.Tensor: + """Normalize batch loss multipliers to a float tensor.""" + loss_multiplier = repeated_batch["loss_multiplier"] + if isinstance(loss_multiplier, torch.Tensor): + loss_multiplier_tensor = loss_multiplier.to(device=device, dtype=torch.float32) + else: + loss_multiplier_tensor = torch.tensor( + loss_multiplier, + dtype=torch.float32, + device=device, + ) + loss_multiplier_tensor = loss_multiplier_tensor.reshape(-1) + if loss_multiplier_tensor.shape[0] != batch_size: + raise ValueError( + "loss_multiplier must have one element per sample; " + f"got {loss_multiplier_tensor.shape[0]} values for batch size {batch_size}" + ) + return loss_multiplier_tensor + + +def _get_reward_valid_mask( + repeated_batch: BatchedDataDict[DatumSpec], + rewards: torch.Tensor, +) -> torch.Tensor: + """Normalize optional reward-validity info to a boolean tensor.""" + reward_valid_mask = repeated_batch.get("reward_valid_mask") + if reward_valid_mask is None: + return torch.ones(rewards.shape[0], dtype=torch.bool, device=rewards.device) + if isinstance(reward_valid_mask, torch.Tensor): + reward_valid_mask_tensor = reward_valid_mask.to( + device=rewards.device, + dtype=torch.bool, + ) + else: + reward_valid_mask_tensor = torch.tensor( + reward_valid_mask, + dtype=torch.bool, + device=rewards.device, + ) + reward_valid_mask_tensor = reward_valid_mask_tensor.reshape(-1) + if 
reward_valid_mask_tensor.shape[0] != rewards.shape[0]: + raise ValueError( + "reward_valid_mask must have one element per reward; " + f"got {reward_valid_mask_tensor.shape[0]} values for " + f"{rewards.shape[0]} rewards" + ) + return reward_valid_mask_tensor + + +def _get_stats_valid_mask( + repeated_batch: BatchedDataDict[DatumSpec], + rewards: torch.Tensor, +) -> torch.Tensor: + """Get the sample-level validity mask used for reward statistics.""" + loss_multiplier = _get_batch_loss_multiplier( + repeated_batch, + rewards.shape[0], + device=rewards.device, + ) + reward_valid_mask = _get_reward_valid_mask(repeated_batch, rewards) + return (loss_multiplier > 0) & reward_valid_mask + + +def _get_reward_validity_metrics( + rewards: torch.Tensor, + reward_valid_mask: torch.Tensor, +) -> dict[str, float | int]: + """Compute basic invalid-reward metrics for logging.""" + if reward_valid_mask.numel() == 0: + return { + "invalid_reward_count": 0, + "invalid_reward_rate": 0.0, + "valid_reward_mean": 0.0, + } + + valid_rewards = rewards[reward_valid_mask] + return { + "invalid_reward_count": int((~reward_valid_mask).sum().item()), + "invalid_reward_rate": float((~reward_valid_mask).float().mean().item()), + "valid_reward_mean": float(valid_rewards.mean().item()) + if valid_rewards.numel() > 0 + else 0.0, + } + + def _should_use_async_rollouts(master_config: MasterConfig) -> bool: """Determine if async rollouts should be used based on the configuration. @@ -1593,6 +1682,7 @@ def grpo_train( with timer.time("reward_calculation"): # Extract rewards from final_batch rewards = repeated_batch["total_reward"] + stats_valid_mask = _get_stats_valid_mask(repeated_batch, rewards) print("▶ Computing advantages...", flush=True) if master_config["grpo"].get("calculate_advantages_on_gpu"): @@ -1602,7 +1692,7 @@ def grpo_train( baseline, std = calculate_baseline_and_std_per_prompt( input_ids.cuda(device_id), rewards.cuda(device_id), - torch.ones_like(rewards).cuda(device_id), + stats_valid_mask.cuda(device_id), leave_one_out_baseline=master_config["grpo"][ "use_leave_one_out_baseline" ], @@ -1613,7 +1703,7 @@ def grpo_train( baseline, std = calculate_baseline_and_std_per_prompt( input_ids, rewards, - torch.ones_like(rewards), + stats_valid_mask, leave_one_out_baseline=master_config["grpo"][ "use_leave_one_out_baseline" ], @@ -1641,6 +1731,11 @@ def grpo_train( if not master_config["grpo"]["use_dynamic_sampling"] else repeated_batch["filtered_reward"] ) + reward_valid_mask = _get_reward_valid_mask(repeated_batch, rewards) + stats_valid_mask = _get_stats_valid_mask(repeated_batch, rewards) + reward_validity_metrics = _get_reward_validity_metrics( + rewards, reward_valid_mask + ) baseline = repeated_batch["baseline"] std = repeated_batch["std"] @@ -1671,16 +1766,20 @@ def grpo_train( del std with timer.time("data_processing"): + loss_multiplier = _get_batch_loss_multiplier( + repeated_batch, + len(repeated_batch["message_log"]), + ) + loss_multiplier = loss_multiplier * reward_valid_mask.float() use_overlong_filtering = master_config["grpo"]["overlong_filtering"] if use_overlong_filtering: - loss_multiplier = repeated_batch["loss_multiplier"].clone() truncated = repeated_batch["truncated"] if isinstance(truncated, list): truncated = torch.tensor(truncated, dtype=torch.bool) loss_multiplier[truncated] = 0 - repeated_batch["loss_multiplier"] = loss_multiplier + repeated_batch["loss_multiplier"] = loss_multiplier # Add loss mask to each message in LLMMessageLogType for i, message_log in 
enumerate(repeated_batch["message_log"]): for j, message in enumerate(message_log): @@ -1740,7 +1839,7 @@ def grpo_train( "input_ids": train_data["input_ids"], "input_lengths": train_data["input_lengths"], "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], + "sample_mask": train_data["sample_mask"], **extra_multimodal_data, } ) @@ -1785,6 +1884,7 @@ def grpo_train( prompt_ids=prompt_ids_for_adv, rewards=rewards, mask=mask, + sample_valid_mask=stats_valid_mask, repeated_batch=repeated_batch, logprobs_policy=train_data["prev_logprobs"], logprobs_reference=train_data.get("reference_policy_logprobs"), @@ -1896,6 +1996,7 @@ def grpo_train( "advantages/min": torch.min(response_advantages).detach().item() if response_advantages.numel() > 0 else 0.0, + **reward_validity_metrics, **ds_metrics, } if "moe_metrics" in train_results: @@ -1929,6 +2030,8 @@ def grpo_train( "mean_prompt_length", }: metrics[k] = np.mean(v).item() + elif isinstance(v, (float, int, np.floating, np.integer)): + metrics[k] = float(v) elif isinstance(v, (np.ndarray, list)): metrics[k] = np.sum(v).item() else: @@ -2059,6 +2162,7 @@ def grpo_train( log_data["token_ids"] = train_data["input_ids"].tolist() log_data["token_loss_mask"] = train_data["token_mask"].tolist() log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() + log_data["reward_valid_mask"] = reward_valid_mask.tolist() log_data["advantages"] = train_data["advantages"].tolist() log_data["generation_logprobs"] = train_data[ "generation_logprobs" @@ -2128,6 +2232,10 @@ def grpo_train( ) else: print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") + print( + f" • Invalid Reward Rate: {metrics['invalid_reward_rate']:.4f} ({int(metrics['invalid_reward_count'])} samples)" + ) + print(f" • Avg Valid Reward: {metrics['valid_reward_mean']:.4f}") print( f" • Mean Generation Length: {metrics_logging_data['mean_gen_tokens_per_sample']:.4f}", flush=True, @@ -2232,6 +2340,7 @@ def validate( print(f"▶ Starting validation at step {step}...", flush=True) total_rewards = [] + total_reward_valid_masks = [] total_lengths = [] all_message_logs = [] # Collect all message logs @@ -2284,6 +2393,11 @@ def validate( ) total_rewards.extend(val_batch["total_reward"].tolist()) + total_reward_valid_masks.extend( + _get_reward_valid_mask(val_batch, val_batch["total_reward"]) + .cpu() + .tolist() + ) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) # Collect message logs for later display @@ -2300,9 +2414,25 @@ def validate( num_samples = len(total_rewards) if num_samples > 0: rewards_t = torch.tensor(total_rewards, dtype=torch.float32) + reward_valid_mask_t = torch.tensor( + total_reward_valid_masks, dtype=torch.bool + ) accuracy = rewards_t.mean().item() + invalid_reward_count = int((~reward_valid_mask_t).sum().item()) + invalid_reward_rate = float( + (~reward_valid_mask_t).float().mean().item() + ) + if reward_valid_mask_t.any().item(): + reward_valid_only_mean = float( + rewards_t[reward_valid_mask_t].mean().item() + ) + else: + reward_valid_only_mean = 0.0 else: accuracy = 0.0 + invalid_reward_count = 0 + invalid_reward_rate = 0.0 + reward_valid_only_mean = 0.0 avg_length = ( sum(total_lengths) / len(total_lengths) if len(total_lengths) > 0 else 0.0 @@ -2311,6 +2441,9 @@ def validate( val_metrics = { "accuracy": accuracy, "avg_length": avg_length, + "invalid_reward_count": invalid_reward_count, + "invalid_reward_rate": invalid_reward_rate, + "reward_valid_only_mean": reward_valid_only_mean, 
**additional_metrics_to_report, } @@ -2336,6 +2469,10 @@ def validate( # Print summary of validation results print("\n📊 Validation Results:") print(f" • Accuracy: {accuracy:.4f}") + print( + f" • Invalid Reward Rate: {invalid_reward_rate:.4f} ({invalid_reward_count} samples)" + ) + print(f" • Avg Valid Reward: {reward_valid_only_mean:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) @@ -2349,6 +2486,7 @@ def validate( val_log_data = { "content": all_message_logs, "rewards": total_rewards, + "reward_valid_mask": total_reward_valid_masks, } logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl") @@ -2735,6 +2873,13 @@ def async_grpo_train( del prompt_batched_flat rewards = repeated_batch["total_reward"] + reward_valid_mask = _get_reward_valid_mask( + repeated_batch, rewards + ) + stats_valid_mask = _get_stats_valid_mask(repeated_batch, rewards) + reward_validity_metrics = _get_reward_validity_metrics( + repeated_batch["total_reward"], reward_valid_mask + ) print( f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" @@ -2742,6 +2887,13 @@ def async_grpo_train( # Prepare training data (same as sync version) with timer.time("data_processing"): + loss_multiplier = _get_batch_loss_multiplier( + repeated_batch, + len(repeated_batch["message_log"]), + ) + repeated_batch["loss_multiplier"] = ( + loss_multiplier * reward_valid_mask.float() + ) # Add loss mask to each message for i, message_log in enumerate(repeated_batch["message_log"]): for j, message in enumerate(message_log): @@ -2822,6 +2974,7 @@ def async_grpo_train( prompt_ids=prompt_ids_for_adv, rewards=rewards, mask=mask, + sample_valid_mask=stats_valid_mask, repeated_batch=repeated_batch, logprobs_policy=train_data["prev_logprobs"], logprobs_reference=train_data.get("reference_policy_logprobs"), @@ -2952,6 +3105,7 @@ def async_grpo_train( "advantages/min": torch.min(response_advantages).detach().item() if response_advantages.numel() > 0 else 0.0, + **reward_validity_metrics, } if "moe_metrics" in train_results: metrics.update( @@ -2978,6 +3132,8 @@ def async_grpo_train( "mean_prompt_length", }: metrics[k] = np.mean(v).item() + elif isinstance(v, (float, int, np.floating, np.integer)): + metrics[k] = float(v) else: metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) @@ -3086,6 +3242,7 @@ def async_grpo_train( log_data["token_ids"] = train_data["input_ids"].tolist() log_data["token_loss_mask"] = train_data["token_mask"].tolist() log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() + log_data["reward_valid_mask"] = reward_valid_mask.tolist() log_data["advantages"] = train_data["advantages"].tolist() log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() @@ -3136,6 +3293,10 @@ def async_grpo_train( print(f" • Draft Loss: {metrics['draft_loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") + print( + f" • Invalid Reward Rate: {metrics['invalid_reward_rate']:.4f} ({int(metrics['invalid_reward_count'])} samples)" + ) + print(f" • Avg Valid Reward: {metrics['valid_reward_mean']:.4f}") print(f" • Buffer Size: {buffer_size_current}") print(f" • Avg Trajectory Age: {avg_trajectory_age:.2f} steps") diff --git a/nemo_rl/environments/code_jaccard_environment.py 
b/nemo_rl/environments/code_jaccard_environment.py index d0ea62b4b0..1759bf986f 100644 --- a/nemo_rl/environments/code_jaccard_environment.py +++ b/nemo_rl/environments/code_jaccard_environment.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, TypedDict, Union +from typing import Any, Optional, TypedDict import ray import torch @@ -36,6 +36,12 @@ class CodeJaccardEnvironmentMetadata(TypedDict): ground_truth: str +class CodeJaccardVerifyResult(TypedDict): + scores: list[float] + reward_valid_mask: list[bool] + extracted_answers: list[str | None] | None + + @ray.remote # pragma: no cover class CodeJaccardVerifyWorker: """Worker for evaluating code responses using Jaccard-based similarity.""" @@ -48,7 +54,7 @@ def verify( pred_responses: list[str], ground_truths: list[str], return_extracted_answer: bool = False, - ) -> Union[list[float], tuple[list[float], list[str | None]]]: + ) -> CodeJaccardVerifyResult: """Verify code responses against ground-truth solutions using Jaccard-based similarity. We use a simple text similarity approach (Jaccard over tokenized words) @@ -65,7 +71,10 @@ def verify( If return_extracted_answer is True, returns (scores, extracted_answers). """ results = [] - extracted_answers: list[str | None] = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): try: @@ -73,20 +82,25 @@ def verify( # This is a basic implementation - could be enhanced with more sophisticated metrics score = self._calculate_preference_score(response, ground_truth) results.append(float(score)) + reward_valid_mask.append(True) if return_extracted_answer: + assert extracted_answers is not None # For CodeJaccard, the "extracted answer" is the full response extracted_answers.append(response.strip()) except Exception: results.append(0.0) + reward_valid_mask.append(False) if return_extracted_answer: + assert extracted_answers is not None extracted_answers.append(None) - if return_extracted_answer: - return results, extracted_answers - else: - return results + return { + "scores": results, + "reward_valid_mask": reward_valid_mask, + "extracted_answers": extracted_answers, + } def _calculate_preference_score(self, response: str, ground_truth: str) -> float: """Calculate a Jaccard-based alignment score between response and ground truth. 
@@ -212,18 +226,20 @@ def step( worker_results = ray.get(futures) # Flatten the results and extract both scores and answers - results = [] + results: list[float] = [] + reward_valid_mask: list[bool] = [] extracted_answers: list[str | None] | None = ( [] if return_extracted_answer else None ) for worker_result in worker_results: + results.extend(worker_result["scores"]) + reward_valid_mask.extend(worker_result["reward_valid_mask"]) if return_extracted_answer: - worker_scores, worker_answers = worker_result - results.extend(worker_scores) + worker_answers = worker_result["extracted_answers"] + assert extracted_answers is not None + assert worker_answers is not None extracted_answers.extend(worker_answers) - else: - results.extend(worker_result) # Create observations based on Jaccard alignment observations = [ @@ -240,6 +256,9 @@ def step( rewards = torch.tensor(results).cpu() done = torch.ones_like(rewards).cpu() next_stop_strings = [None] * len(message_log_batch) + reward_valid_mask_tensor = torch.tensor( + reward_valid_mask, dtype=torch.bool + ).cpu() return EnvironmentReturn( observations=observations, @@ -248,6 +267,7 @@ def step( rewards=rewards, terminateds=done, answers=extracted_answers, + reward_valid_mask=reward_valid_mask_tensor, ) def global_post_process_and_metrics( diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py index 71514e0142..1ae6bb1a84 100644 --- a/nemo_rl/environments/interfaces.py +++ b/nemo_rl/environments/interfaces.py @@ -39,6 +39,9 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]): rewards: the rewards for this turn. terminateds: whether the episode ended this turn. answers: the answers for this turn. + reward_valid_mask: optional boolean mask marking whether the reward for + each sample is valid. Invalid rewards should be masked + out of training statistics and loss computation. """ observations: list[dict[str, str]] @@ -47,6 +50,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]): rewards: Tensor ## Shape [B] for single-reward, [B, num_reward_components] for multi-reward (e.g. GDPO) terminateds: Tensor answers: list[str | None] | None + reward_valid_mask: Tensor | None = None class EnvironmentInterface(abc.ABC, Generic[MetadataT]): diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index 6da76d04db..4dc97c66a2 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -15,7 +15,7 @@ import io import logging import re -from typing import Any, NotRequired, TypedDict, Union +from typing import Any, NotRequired, TypedDict import ray import torch @@ -47,6 +47,18 @@ class MathEnvConfig(TypedDict): math_verify_impl: NotRequired[str | None] +class SingleRewardVerificationResult(TypedDict): + scores: list[float] + reward_valid_mask: list[bool] + extracted_answers: list[str | None] | None + + +class MultiRewardVerificationResult(TypedDict): + scores: list[list[float]] + reward_valid_mask: list[bool] + extracted_answers: list[str | None] | None + + @contextlib.contextmanager def _mute_output(): devnull_out, devnull_err = io.StringIO(), io.StringIO() @@ -78,7 +90,7 @@ def verify( ground_truths: list[str], return_extracted_answer: bool = False, **kwargs, - ) -> Union[list[float], tuple[list[float], list[str | None]]]: + ) -> SingleRewardVerificationResult: """Verify the correctness of the predicted responses against the ground truth. 
Args: @@ -91,18 +103,21 @@ def verify( If return_extracted_answer is True, returns (scores, extracted_answers). """ results = [] - extracted_answers: list[str | None] = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): try: with _mute_output(): math_verify_impl = kwargs.get("math_verify_impl", "hf_math_verify") - if kwargs.get("math_verify_impl") == "dapo_math_verify": + if math_verify_impl == "dapo_math_verify": # This compute_score is from the DAPO Math Verifier from Verl reward_dict = dapo_math_verify(response, ground_truth) ret_score = reward_dict["score"] extracted_answer = reward_dict["pred"] - elif kwargs.get("math_verify_impl") == "hf_math_verify": + elif math_verify_impl == "hf_math_verify": ground_truth_parsable = "\\boxed{" + ground_truth + "}" ret_score, extracted_answer = self.verify_func( [ground_truth_parsable], [response] @@ -113,8 +128,10 @@ def verify( ) results.append(float(ret_score)) + reward_valid_mask.append(True) if return_extracted_answer: + assert extracted_answers is not None # Make sure the extracted answer is not None and is a list of two elements assert extracted_answer is not None assert len(extracted_answer) == 2 @@ -133,12 +150,16 @@ def verify( # to catch it. except (Exception, TimeoutException): results.append(0.0) - extracted_answers.append(None) + reward_valid_mask.append(False) + if return_extracted_answer: + assert extracted_answers is not None + extracted_answers.append(None) - if return_extracted_answer: - return results, extracted_answers - else: - return results + return { + "scores": results, + "reward_valid_mask": reward_valid_mask, + "extracted_answers": extracted_answers, + } @ray.remote # pragma: no cover @@ -149,7 +170,7 @@ def verify( ground_truths: list[str], return_extracted_answer: bool = False, **kwargs, - ) -> Union[list[float], tuple[list[float], list[str | None]]]: + ) -> SingleRewardVerificationResult: """Verify the correctness of the predicted responses against the ground truth. Args: @@ -162,7 +183,10 @@ def verify( If return_extracted_answer is True, returns (scores, extracted_answers). """ results = [] - extracted_answers: list[str | None] = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): response = answer_parsing.normalize_response(response) @@ -179,12 +203,16 @@ def verify( break score = 1.0 if extracted_answer == ground_truth else 0.0 results.append(score) - extracted_answers.append(extracted_answer) + reward_valid_mask.append(True) + if return_extracted_answer: + assert extracted_answers is not None + extracted_answers.append(extracted_answer) - if return_extracted_answer: - return results, extracted_answers - else: - return results + return { + "scores": results, + "reward_valid_mask": reward_valid_mask, + "extracted_answers": extracted_answers, + } @ray.remote # pragma: no cover @@ -195,7 +223,7 @@ def verify( ground_truths: list[str], return_extracted_answer: bool = False, **kwargs, - ) -> Union[list[float], tuple[list[float], list[str | None]]]: + ) -> SingleRewardVerificationResult: """Verify the correctness of the predicted responses against the ground truth. Args: @@ -208,7 +236,10 @@ def verify( If return_extracted_answer is True, returns (scores, extracted_answers). 
""" results = [] - extracted_answers: list[str | None] = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): ground_truth = answer_parsing.normalize_response(ground_truth) @@ -221,13 +252,16 @@ def verify( ) score = 1.0 if extracted_answer == ground_truth else 0.0 results.append(score) + reward_valid_mask.append(True) if return_extracted_answer: + assert extracted_answers is not None extracted_answers.append(extracted_answer) - if return_extracted_answer: - return results, extracted_answers - else: - return results + return { + "scores": results, + "reward_valid_mask": reward_valid_mask, + "extracted_answers": extracted_answers, + } @ray.remote # pragma: no cover @@ -253,7 +287,7 @@ def verify( ground_truths: list[str], return_extracted_answer: bool = False, **kwargs, - ) -> Union[list[list[float]], tuple[list[list[float]], list[str | None]]]: + ) -> MultiRewardVerificationResult: """Verify the correctness of the predicted responses against the ground truth. Args: @@ -297,7 +331,10 @@ def format_reward_func(completions, **kwargs) -> list[float]: return rewards results = [[] for _ in range(self.number_of_rewards)] - extracted_answers: list[str | None] = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): try: @@ -314,8 +351,10 @@ def format_reward_func(completions, **kwargs) -> list[float]: results[0].extend(cor_reward) results[1].extend(int_reward) results[2].extend(format_reward) + reward_valid_mask.append(True) if return_extracted_answer: + assert extracted_answers is not None extracted_answer = extract_xml_answer(response) extracted_answers.append(extracted_answer) @@ -326,13 +365,16 @@ def format_reward_func(completions, **kwargs) -> list[float]: results[0].append(0.0) results[1].append(0.0) results[2].append(0.0) - extracted_answers.append(None) + reward_valid_mask.append(False) + if return_extracted_answer: + assert extracted_answers is not None + extracted_answers.append(None) - if return_extracted_answer: - return results, extracted_answers - else: - # return results --> [[0,1,0], [0,2,0], .........] 
- return results + return { + "scores": results, + "reward_valid_mask": reward_valid_mask, + "extracted_answers": extracted_answers, + } class MathEnvironmentMetadata(TypedDict): @@ -469,15 +511,19 @@ def step( worker_results = ray.get(futures) # Flatten the results and extract both scores and answers - results = [] + results: list[float] = [] + reward_valid_mask: list[bool] = [] extracted_answers: list[str | None] | None = ( [] if return_extracted_answer else None ) for worker_result in worker_results: - worker_scores = worker_result + worker_scores = worker_result["scores"] + reward_valid_mask.extend(worker_result["reward_valid_mask"]) if return_extracted_answer: - worker_scores, worker_answers = worker_result + worker_answers = worker_result["extracted_answers"] + assert extracted_answers is not None + assert worker_answers is not None extracted_answers.extend(worker_answers) results.extend(worker_scores) @@ -495,6 +541,9 @@ def step( rewards = torch.tensor(results).cpu() done = torch.ones_like(rewards).cpu() next_stop_strings = [None] * len(message_log_batch) + reward_valid_mask_tensor = torch.tensor( + reward_valid_mask, dtype=torch.bool + ).cpu() return EnvironmentReturn( observations=observations, @@ -503,6 +552,7 @@ def step( rewards=rewards, terminateds=done, answers=extracted_answers, + reward_valid_mask=reward_valid_mask_tensor, ) @@ -566,16 +616,20 @@ def step( worker_results = ray.get(futures) # Flatten the results and extract both scores and answers - number_of_rewards = len(worker_results[0]) + number_of_rewards = len(worker_results[0]["scores"]) results = [[] for _ in range(number_of_rewards)] + reward_valid_mask: list[bool] = [] extracted_answers: list[str | None] | None = ( [] if return_extracted_answer else None ) for worker_result in worker_results: - worker_scores = worker_result + worker_scores = worker_result["scores"] + reward_valid_mask.extend(worker_result["reward_valid_mask"]) if return_extracted_answer: - worker_scores, worker_answers = worker_result + worker_answers = worker_result["extracted_answers"] + assert extracted_answers is not None + assert worker_answers is not None extracted_answers.extend(worker_answers) for i in range(number_of_rewards): results[i].extend(worker_scores[i]) @@ -595,6 +649,9 @@ def step( ## hard fixed this done to done = torch.ones(rewards.shape[0]).cpu() next_stop_strings = [None] * len(message_log_batch) + reward_valid_mask_tensor = torch.tensor( + reward_valid_mask, dtype=torch.bool + ).cpu() return EnvironmentReturn( observations=observations, @@ -603,4 +660,5 @@ def step( rewards=rewards, terminateds=done, answers=extracted_answers, + reward_valid_mask=reward_valid_mask_tensor, ) diff --git a/nemo_rl/environments/vlm_environment.py b/nemo_rl/environments/vlm_environment.py index 7e4943c3b2..2de9ac6254 100644 --- a/nemo_rl/environments/vlm_environment.py +++ b/nemo_rl/environments/vlm_environment.py @@ -45,6 +45,11 @@ class VLMEnvConfig(TypedDict): reward_functions: List[dict[str, Any]] # list of reward functions and their weights +class VLMVerifyResult(TypedDict): + scores: list[float] + reward_valid_mask: list[bool] + + @contextlib.contextmanager def _mute_output(): devnull_out, devnull_err = io.StringIO(), io.StringIO() @@ -93,7 +98,7 @@ def __init__(self, cfg: VLMEnvConfig) -> None: def verify( self, pred_responses: list[str], ground_truths: list[str] - ) -> list[float]: + ) -> VLMVerifyResult: """Verify the correctness of the predicted responses against the ground truth. 
Args: @@ -103,20 +108,20 @@ def verify( Returns: list[float]. The rewards for each predicted response. """ - results = [] + results: list[float] = [] + reward_valid_mask: list[bool] = [] for response, ground_truth in zip(pred_responses, ground_truths): + ret_score = 0.0 + is_valid = True try: with _mute_output(): - try: - ret_score, _ = self.verify_func(ground_truth, response) - except Exception as e: - ret_score = 0.0 - print(f"Error in verify_func: {e}") - results.append(float(ret_score)) - except Exception as e: - print(f"Error in verify: {e}") - results.append(0.0) - return results + ret_score, _ = self.verify_func(ground_truth, response) + except Exception: + is_valid = False + + results.append(float(ret_score)) + reward_valid_mask.append(is_valid) + return {"scores": results, "reward_valid_mask": reward_valid_mask} class VLMEnvironmentMetadata(TypedDict): @@ -185,10 +190,15 @@ def step( # type: ignore[override] ) ] - results = ray.get(futures) + worker_results = ray.get(futures) # flatten the results - results = [item for sublist in results for item in sublist] + results = [item for worker_result in worker_results for item in worker_result["scores"]] + reward_valid_mask = [ + is_valid + for worker_result in worker_results + for is_valid in worker_result["reward_valid_mask"] + ] observations = [ { "role": "environment", @@ -202,6 +212,9 @@ def step( # type: ignore[override] # create a tensor of rewards and done flags rewards = torch.tensor(results).cpu() done = torch.ones_like(rewards).cpu() + reward_valid_mask_tensor = torch.tensor( + reward_valid_mask, dtype=torch.bool + ).cpu() next_stop_strings = [None] * len(message_log_batch) @@ -212,6 +225,7 @@ def step( # type: ignore[override] rewards=rewards, terminateds=done, answers=None, + reward_valid_mask=reward_valid_mask_tensor, ) def global_post_process_and_metrics( diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index d5186e868a..d34c7359a7 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -237,6 +237,25 @@ async def generate_responses_async( return batch, generated_ids, gen_metrics +def _normalize_reward_valid_mask( + reward_valid_mask: torch.Tensor | list[bool] | None, batch_size: int +) -> torch.Tensor: + """Normalize optional reward validity info to a boolean tensor.""" + if reward_valid_mask is None: + return torch.ones(batch_size, dtype=torch.bool) + mask: torch.Tensor + if isinstance(reward_valid_mask, torch.Tensor): + mask = reward_valid_mask.to(dtype=torch.bool).cpu().reshape(-1) + else: + mask = torch.tensor(reward_valid_mask, dtype=torch.bool).reshape(-1) + if mask.numel() != batch_size: + raise ValueError( + "reward_valid_mask must have one element per sample; " + f"got {mask.numel()} values for batch size {batch_size}" + ) + return mask + + def calculate_rewards( batch: BatchedDataDict[DatumSpec], task_to_env: dict[str, EnvironmentInterface], @@ -254,6 +273,7 @@ def calculate_rewards( - next_stop_strings: List of stop strings for the next generation step. - rewards: Tensor of rewards for the last turn. - terminateds: Tensor of booleans indicating if an episode ended naturally. + - reward_valid_mask: Tensor of booleans marking whether each reward is valid. 
""" # Extract message logs for environment (most recent interaction) to_env = [ @@ -296,18 +316,20 @@ def calculate_rewards( all_metadata = [] # Store extracted metadata all_indices_order = [] all_answers = [] + all_reward_valid_masks = [] for future, result in zip(futures, results): indices = future_to_indices[future] - # Environment step returns: EnvironmentReturn - ( - env_observations, - metadata, - next_stop_strings, - task_rewards, - terminateds, - answers, - ) = result + env_observations = result.observations + metadata = result.metadata + next_stop_strings = result.next_stop_strings + task_rewards = result.rewards + terminateds = result.terminateds + answers = result.answers + task_reward_valid_mask = _normalize_reward_valid_mask( + result.reward_valid_mask, + len(task_rewards), + ) if next_stop_strings is None: next_stop_strings = [None] * len(task_rewards) if answers is None: @@ -322,6 +344,7 @@ def calculate_rewards( all_next_stop_strings.append(next_stop_strings[i]) all_metadata.append(metadata[i]) all_answers.append(answers[i]) + all_reward_valid_masks.append(bool(task_reward_valid_mask[i].item())) # Sort results by original index to maintain order sorted_indices = sorted( @@ -339,6 +362,9 @@ def calculate_rewards( next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] metadata = [all_metadata[i] for i in sorted_indices] # Sort metadata answers = [all_answers[i] for i in sorted_indices] + reward_valid_mask = torch.tensor( + [all_reward_valid_masks[i] for i in sorted_indices], dtype=torch.bool + ) return EnvironmentReturn( observations=env_observations, @@ -347,6 +373,7 @@ def calculate_rewards( rewards=rewards, terminateds=terminateds, answers=answers, + reward_valid_mask=reward_valid_mask, ) @@ -379,6 +406,7 @@ def run_multi_turn_rollout( batch_size = len(current_batch["message_log"]) active_indices = torch.arange(batch_size) total_rewards = torch.zeros(batch_size, dtype=torch.float32) + sample_reward_valid_mask = torch.ones(batch_size, dtype=torch.bool) # Multi_rewards: number of components inferred from first env_output (1 for single-reward envs) number_of_rewards: int | None = None @@ -468,6 +496,11 @@ def run_multi_turn_rollout( # Calculate rewards and get environment feedback env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) + turn_reward_valid_mask = _normalize_reward_valid_mask( + env_output.reward_valid_mask, + len(active_indices), + ) + sample_reward_valid_mask[active_indices] &= turn_reward_valid_mask # Infer number of reward components on first turn (supports single- and multi-reward envs) if number_of_rewards is None: @@ -564,6 +597,7 @@ def run_multi_turn_rollout( # Add total rewards to the final batch current_batch["total_reward"] = total_rewards current_batch["truncated"] = sample_truncated + current_batch["reward_valid_mask"] = sample_reward_valid_mask # Expose per-component rewards (reward1, reward2, ...) 
for multi-reward envs only; GRPO uses total_reward if multi_rewards is not None: num_reward_components = multi_rewards.shape[1] @@ -579,6 +613,8 @@ def run_multi_turn_rollout( "natural_termination_rate": float(sample_terminated.float().mean().item()), "truncation_rate": float(sample_truncated.float().mean().item()), "max_turns_reached_rate": float(sample_max_turns_reached.float().mean().item()), + "invalid_reward_count": int((~sample_reward_valid_mask).sum().item()), + "invalid_reward_rate": float((~sample_reward_valid_mask).float().mean().item()), # Token usage metrics "mean_total_tokens_per_sample": float( sample_token_counts.float().mean().item() @@ -710,6 +746,7 @@ async def run_sample_multi_turn_rollout( terminated = False truncated = False max_turns_reached = False + reward_valid = True # Track per-turn metrics turn_gen_tokens = [] @@ -775,6 +812,9 @@ async def run_sample_multi_turn_rollout( # Get environment feedback env_output = calculate_rewards(sample_batch, task_to_env) + reward_valid = reward_valid and bool( + _normalize_reward_valid_mask(env_output.reward_valid_mask, 1)[0].item() + ) # Update total reward and optional per-reward signals (reward1, reward2, ... rewardN) if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1: multi_reward_seen = True @@ -832,6 +872,7 @@ async def run_sample_multi_turn_rollout( "extra_env_info": current_extra_env_info, "task_name": task_name, "total_reward": torch.tensor(total_reward), + "reward_valid_mask": torch.tensor(reward_valid, dtype=torch.bool), "stop_strings": current_stop_strings, "idx": sample_idx, } @@ -849,6 +890,7 @@ async def run_sample_multi_turn_rollout( "truncated": truncated, "max_turns_reached": max_turns_reached, "total_reward": total_reward, + "reward_valid": reward_valid, "turn_gen_tokens": turn_gen_tokens, "turn_input_tokens": turn_input_tokens, "turn_total_tokens": turn_total_tokens, @@ -951,6 +993,9 @@ async def run_single_sample_with_error_handling(i, sample_state): "total_reward": torch.stack( [state["total_reward"] for state in final_sample_states] ), + "reward_valid_mask": torch.stack( + [state["reward_valid_mask"] for state in final_sample_states] + ).bool(), "idx": [ state.get("idx", i) for i, state in enumerate(final_sample_states) ], @@ -1006,6 +1051,13 @@ async def run_single_sample_with_error_handling(i, sample_state): m["max_turns_reached"] for m in all_sample_metrics ) / batch_size, + "invalid_reward_count": sum( + not m["reward_valid"] for m in all_sample_metrics + ), + "invalid_reward_rate": sum( + not m["reward_valid"] for m in all_sample_metrics + ) + / batch_size, # Token usage metrics "mean_total_tokens_per_sample": sum( m["total_tokens"] for m in all_sample_metrics diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 2ddbf001c9..7e3330dd74 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -260,21 +260,31 @@ def mock_ray_get(ref): @ray.remote(num_cpus=0) class MockEnvironment(EnvironmentInterface): - def __init__(self, rewards: list[float]): + def __init__( + self, + rewards: list[float], + reward_valid_mask: list[bool] | None = None, + ): self.rewards = rewards + self.reward_valid_mask = reward_valid_mask or [True] * len(rewards) self._calls = 0 def step( self, messages: list[LLMMessageLogType], env_info: list[dict] ) -> EnvironmentReturn: self._calls += 1 - return ( - [{"role": "environment", "content": "observation"}] * len(messages), - [{}] * len(messages), - [[]] * len(messages), - self.rewards, - [True] * 
len(messages), - [None] * len(messages), + batch_size = len(messages) + return EnvironmentReturn( + observations=[{"role": "environment", "content": "observation"}] + * batch_size, + metadata=[{}] * batch_size, + next_stop_strings=[[]] * batch_size, + rewards=torch.tensor(self.rewards[:batch_size], dtype=torch.float32), + terminateds=torch.ones(batch_size, dtype=torch.bool), + answers=[None] * batch_size, + reward_valid_mask=torch.tensor( + self.reward_valid_mask[:batch_size], dtype=torch.bool + ), ) def get_calls(self): @@ -330,18 +340,17 @@ def test_calculate_rewards_single_task(mock_env): batch = create_mock_batch(2, task_names, message_logs) # Calculate rewards - env_observations, metadata, next_stop_strings, rewards, terminateds, answers = ( - calculate_rewards(batch, task_to_env) - ) + result = calculate_rewards(batch, task_to_env) # Verify results - assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) - assert len(env_observations) == 2 - assert len(terminateds) == 2 - assert len(next_stop_strings) == 2 - assert len(metadata) == 2 - assert len(answers) == 2 - assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) + assert torch.allclose(result.rewards, torch.tensor([1.0, 2.0])) + assert torch.equal(result.reward_valid_mask, torch.tensor([True, True])) + assert len(result.observations) == 2 + assert len(result.terminateds) == 2 + assert len(result.next_stop_strings) == 2 + assert len(result.metadata) == 2 + assert len(result.answers) == 2 + assert torch.allclose(result.rewards, torch.tensor([1.0, 2.0])) assert ( ray.get(mock_env.get_calls.remote()) == 1 ) # Should only call once for all samples of same task @@ -366,18 +375,19 @@ def test_calculate_rewards_multiple_tasks(mock_envs): batch = create_mock_batch(4, task_names, message_logs) # Calculate rewards - env_observations, metadata, next_stop_strings, rewards, terminateds, answers = ( - calculate_rewards(batch, mock_envs) - ) + result = calculate_rewards(batch, mock_envs) # Verify results - assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) - assert len(env_observations) == 4 - assert len(terminateds) == 4 - assert len(next_stop_strings) == 4 - assert len(metadata) == 4 - assert len(answers) == 4 - assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) + assert torch.allclose(result.rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) + assert torch.equal( + result.reward_valid_mask, torch.tensor([True, True, True, True]) + ) + assert len(result.observations) == 4 + assert len(result.terminateds) == 4 + assert len(result.next_stop_strings) == 4 + assert len(result.metadata) == 4 + assert len(result.answers) == 4 + assert torch.allclose(result.rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) assert ( ray.get(mock_envs["math"].get_calls.remote()) == 1 ) # One call for all math samples @@ -394,17 +404,16 @@ def test_calculate_rewards_empty_batch(mock_env): batch = create_mock_batch(0, [], []) # Calculate rewards - env_observations, metadata, next_stop_strings, rewards, terminateds, answers = ( - calculate_rewards(batch, task_to_env) - ) + result = calculate_rewards(batch, task_to_env) # Verify results - assert len(rewards) == 0 - assert len(env_observations) == 0 - assert len(terminateds) == 0 - assert len(next_stop_strings) == 0 - assert len(metadata) == 0 - assert len(answers) == 0 + assert len(result.rewards) == 0 + assert len(result.reward_valid_mask) == 0 + assert len(result.observations) == 0 + assert len(result.terminateds) == 0 + assert len(result.next_stop_strings) == 0 + assert len(result.metadata) == 0 
+ assert len(result.answers) == 0 assert ( ray.get(mock_env.get_calls.remote()) == 0 ) # Should not call environment for empty batch @@ -425,6 +434,49 @@ def test_calculate_rewards_missing_environment(): calculate_rewards(batch, task_to_env) +def test_calculate_rewards_preserves_reward_valid_mask_order(): + """Test reward validity masks are preserved after task grouping and sorting.""" + math_env = MockEnvironment.remote( + rewards=[1.0, 2.0], reward_valid_mask=[True, False] + ) + code_env = MockEnvironment.remote( + rewards=[3.0, 4.0], reward_valid_mask=[False, True] + ) + try: + batch = create_mock_batch( + 4, + ["math", "math", "code", "code"], + [ + [ + {"role": "user", "content": "1+1"}, + {"role": "assistant", "content": "2"}, + ], + [ + {"role": "user", "content": "2+2"}, + {"role": "assistant", "content": "4"}, + ], + [ + {"role": "user", "content": "print('hello')"}, + {"role": "assistant", "content": "hello"}, + ], + [ + {"role": "user", "content": "print('world')"}, + {"role": "assistant", "content": "world"}, + ], + ], + ) + + result = calculate_rewards(batch, {"math": math_env, "code": code_env}) + + assert torch.allclose(result.rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) + assert torch.equal( + result.reward_valid_mask, torch.tensor([True, False, False, True]) + ) + finally: + ray.kill(math_env) + ray.kill(code_env) + + def test_dapo_dynamic_sampling_filters_nonzero_std(): """Test that DAPO dynamic sampling only selects prompts with non-zero standard deviation.""" # Create mock batch data with 6 prompts (2 prompts * 3 generations each) @@ -1675,6 +1727,32 @@ def test_grpo_advantage_estimator_zero_std(): assert torch.allclose(result[2:], expected_prompt_1, rtol=1e-4) +def test_grpo_advantage_estimator_masks_invalid_samples(): + """Test GRPOAdvantageEstimator ignores invalid samples in baseline computation.""" + estimator = GRPOAdvantageEstimator( + { + "use_leave_one_out_baseline": False, + "normalize_rewards": False, + }, + {}, + ) + + prompt_ids = torch.tensor([[0], [0], [0]]) + rewards = torch.tensor([1.0, 999.0, 3.0]) + mask = torch.ones(3, 2) + sample_valid_mask = torch.tensor([True, False, True]) + + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + mask=mask, + sample_valid_mask=sample_valid_mask, + ) + + expected = torch.tensor([[-1.0], [0.0], [1.0]]).expand(3, 2) + assert torch.allclose(result, expected) + + def test_grpo_advantage_estimator_tensor_shapes(): """Test GRPOAdvantageEstimator with different tensor shapes. 
@@ -1866,6 +1944,35 @@ def test_gdpo_advantage_estimator_single_reward(): estimator.compute_advantage(prompt_ids, None, mask, repeated_batch) +def test_gdpo_advantage_estimator_masks_invalid_samples(): + """Test GDPOAdvantageEstimator ignores invalid samples in each reward head.""" + estimator_config = { + "use_leave_one_out_baseline": False, + "normalize_rewards": False, + } + loss_config = {} + estimator = GDPOAdvantageEstimator(estimator_config, loss_config) + + prompt_ids = torch.tensor([[0], [0], [0]]) + mask = torch.ones(3, 1) + repeated_batch = { + "reward1": torch.tensor([1.0, 999.0, 3.0]), + "reward2": torch.tensor([0.0, 999.0, 2.0]), + } + sample_valid_mask = torch.tensor([True, False, True]) + + result = estimator.compute_advantage( + prompt_ids, + None, + mask, + repeated_batch, + sample_valid_mask=sample_valid_mask, + ) + + expected = torch.tensor([[-0.7071], [0.0], [0.7071]]) + assert torch.allclose(result, expected, rtol=1e-3, atol=1e-4) + + # ============================================================================ # Tests for ReinforcePlusPlusAdvantageEstimator class # ============================================================================ @@ -1909,6 +2016,78 @@ def test_reinforce_plus_plus_global_normalization(): assert result[0, 0] < result[1, 0] < result[2, 0] < result[3, 0] +def test_reinforce_plus_plus_masks_invalid_samples(): + """Test Reinforce++ excludes invalid samples from baseline and normalization.""" + estimator_config = { + "minus_baseline": True, + } + loss_config = { + "use_kl_in_reward": False, + "reference_policy_kl_penalty": 0.0001, + "reference_policy_kl_type": "k2", + } + estimator = ReinforcePlusPlusAdvantageEstimator(estimator_config, loss_config) + + prompt_ids = torch.tensor([[0], [0], [0]]) + rewards = torch.tensor([1.0, 999.0, 3.0]) + sample_valid_mask = torch.tensor([True, False, True]) + mask = sample_valid_mask.unsqueeze(-1).repeat(1, 2).float() + + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + mask=mask, + sample_valid_mask=sample_valid_mask, + ) + + expected = torch.tensor([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]) + assert torch.allclose(result, expected) + + +def test_reinforce_plus_plus_masks_invalid_samples_after_kl_penalty(): + """Test Reinforce++ keeps invalid samples masked even when KL reward is enabled.""" + estimator_config = { + "minus_baseline": True, + } + loss_config = { + "use_kl_in_reward": True, + "reference_policy_kl_penalty": 1.0, + "reference_policy_kl_type": "k1", + } + estimator = ReinforcePlusPlusAdvantageEstimator(estimator_config, loss_config) + + prompt_ids = torch.tensor([[0], [0], [0]]) + rewards = torch.tensor([1.0, 999.0, 3.0]) + sample_valid_mask = torch.tensor([True, False, True]) + mask = torch.ones(3, 2) + logprobs_policy = torch.tensor( + [ + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + logprobs_reference = torch.tensor( + [ + [0.0, 0.0], + [2.0, 2.0], + [0.0, 0.0], + ] + ) + + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + mask=mask, + sample_valid_mask=sample_valid_mask, + logprobs_policy=logprobs_policy, + logprobs_reference=logprobs_reference, + ) + + expected = torch.tensor([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]) + assert torch.allclose(result, expected) + + # ============================================================================ # Tests for validate function # ============================================================================ diff --git a/tests/unit/environments/test_code_jaccard_environment.py 
b/tests/unit/environments/test_code_jaccard_environment.py
index 0880fcc6f6..1387803eeb 100644
--- a/tests/unit/environments/test_code_jaccard_environment.py
+++ b/tests/unit/environments/test_code_jaccard_environment.py
@@ -16,8 +16,12 @@
 import pytest
 import ray
+import torch
 
-from nemo_rl.environments.code_jaccard_environment import CodeJaccardEnvConfig
+from nemo_rl.environments.code_jaccard_environment import (
+    CodeJaccardEnvConfig,
+    CodeJaccardVerifyWorker,
+)
 from nemo_rl.environments.utils import create_env
@@ -66,6 +70,7 @@ def test_code_jaccard_basic_alignment(code_jaccard_env):
     assert result.rewards.shape == (2,)
     assert float(result.rewards[0]) == pytest.approx(1.0, rel=0, abs=1e-6)
     assert float(result.rewards[1]) == pytest.approx(1.0, rel=0, abs=1e-6)
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True, True]))
     # Terminated flags set
     assert result.terminateds.shape == (2,)
     assert all(result.terminateds == 1.0)
@@ -90,6 +95,7 @@ def test_code_jaccard_misalignment(code_jaccard_env):
     # Reward should be between 0 and 1, and reasonably low for disjoint tokens
     score = float(result.rewards[0])
     assert 0.0 <= score <= 0.5
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True]))
     assert result.terminateds.shape == (1,)
     assert result.terminateds[0] == 1.0
@@ -116,6 +122,7 @@ def test_code_jaccard_answers_return(code_jaccard_env):
     assert result.answers is not None
     assert result.answers == ["x = a + b", "def add(a, b): return a + b"]
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True, True]))
 
 
 def test_code_jaccard_empty_input(code_jaccard_env):
@@ -125,3 +132,23 @@ def test_code_jaccard_empty_input(code_jaccard_env):
     assert len(result.metadata) == 0
     assert result.rewards.shape == (0,)
     assert result.terminateds.shape == (0,)
+    assert result.reward_valid_mask.shape == (0,)
+
+
+def test_code_jaccard_worker_marks_invalid_rewards_on_exception():
+    """Invalid worker inputs should produce zero reward and an invalid mask."""
+    worker = CodeJaccardVerifyWorker.remote()
+    try:
+        result = ray.get(
+            worker.verify.remote(
+                [None],  # type: ignore[list-item]
+                ["print('hello')"],
+                True,
+            )
+        )
+    finally:
+        ray.kill(worker)
+
+    assert result["scores"] == [0.0]
+    assert result["reward_valid_mask"] == [False]
+    assert result["extracted_answers"] == [None]
diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py
index 8f721f97dd..5a1b7fa015 100644
--- a/tests/unit/environments/test_math_environment.py
+++ b/tests/unit/environments/test_math_environment.py
@@ -15,7 +15,9 @@
 import pytest
 import ray
+import torch
 
+from nemo_rl.environments.math_environment import HFVerifyWorker
 from nemo_rl.environments.utils import create_env
 
 # ============================================================================
@@ -47,6 +49,28 @@ def math_multi_reward_env():
     time.sleep(0.1)
 
 
+@pytest.fixture(scope="module")
+def invalid_math_env():
+    """Create a MathEnvironment actor configured to force verifier failures."""
+    env = create_env("math", {"num_workers": 2, "math_verify_impl": "invalid_impl"})
+    yield env
+    env.shutdown.remote()
+    ray.kill(env)
+    time.sleep(0.1)
+
+
+@pytest.fixture(scope="module")
+def invalid_math_multi_reward_env():
+    """Create a MathMultiRewardEnvironment actor configured to force verifier failures."""
+    env = create_env(
+        "math_multi_reward", {"num_workers": 2, "math_verify_impl": "invalid_impl"}
+    )
+    yield env
+    env.shutdown.remote()
+    ray.kill(env)
+    time.sleep(0.1)
+
+
 @pytest.fixture(scope="module")
 def multichoice_env(request):
     """Create a MathEnvironment actor for testing."""
@@ -240,6 +264,7 @@ def test_math_env_step_basic(math_env, basic_test_data):
     # Check rewards and done flags
     assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)"
     assert all(result.rewards == 1.0), "All rewards should be 1.0 for correct answers"
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True, True, True]))
     assert result.terminateds.shape == (3,), (
         "Terminated flags should be a tensor of shape (3,)"
    )
@@ -276,6 +301,7 @@ def test_multi_reward_env_step_basic(math_multi_reward_env, multi_reward_test_da
 
     # Check rewards: shape (batch_size=3, number_of_rewards=3)
     assert result.rewards.shape == (3, 3), "Rewards should be a tensor of shape (3, 3)"
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True, True, True]))
 
     # Check rewards for each data point
     # First reward: correctness reward 1.0, int reward 1.0, format reward 1.0
@@ -339,6 +365,7 @@ def test_multichoice_env_step_basic(multichoice_env, multichoice_test_data):
         "The first two rewards should be 1.0 for correct answers"
     )
     assert result.rewards[2] == 0.0, "The third reward should be 0.0 for wrong answer"
+    assert torch.equal(result.reward_valid_mask, torch.tensor([True, True, True]))
     assert result.terminateds.shape == (3,), (
         "Terminated flags should be a tensor of shape (3,)"
     )
@@ -450,6 +477,7 @@ def test_math_exception_handling(math_env):
     # Program should not crash
     assert result.rewards.shape == (1,), "Rewards should be a tensor of shape (1,)"
     assert result.rewards[0] == 0.0, "Reward should be 0.0"
+    assert result.reward_valid_mask[0] == False, "Reward should be marked invalid"
 
 
 def test_math_timeout_handling(math_env):
@@ -499,3 +527,49 @@ def test_math_timeout_handling(math_env):
         "Terminated flags should be a tensor of shape (1,)"
     )
     assert result.terminateds[0] == 1.0, "Terminated flag should be 1.0"
+    assert result.reward_valid_mask[0] == False, "Reward should be marked invalid"
+
+
+def test_math_invalid_verifier_marks_rewards_invalid(invalid_math_env, basic_test_data):
+    """Test forced verifier failures propagate an invalid reward mask."""
+    result = ray.get(
+        invalid_math_env.step.remote(
+            basic_test_data["message_log_batch"], basic_test_data["metadata"]
+        )
+    )
+
+    assert torch.all(result.rewards == 0.0)
+    assert torch.equal(result.reward_valid_mask, torch.tensor([False, False, False]))
+
+
+def test_hf_verify_worker_defaults_to_hf_math_verify():
+    """Test direct worker calls use hf_math_verify when the kwarg is omitted."""
+    worker = HFVerifyWorker.remote()
+    try:
+        result = ray.get(
+            worker.verify.remote(
+                ["2 + 2 = \\boxed{4}"],
+                ["4"],
+            )
+        )
+    finally:
+        ray.kill(worker)
+
+    assert result["scores"] == [1.0]
+    assert result["reward_valid_mask"] == [True]
+
+
+def test_math_multi_reward_invalid_verifier_marks_rewards_invalid(
+    invalid_math_multi_reward_env, multi_reward_test_data
+):
+    """Test forced verifier failures propagate invalid masks for multi-reward math."""
+    result = ray.get(
+        invalid_math_multi_reward_env.step.remote(
+            multi_reward_test_data["message_log_batch"],
+            multi_reward_test_data["metadata"],
+        )
+    )
+
+    assert result.rewards.shape == (3, 3)
+    assert torch.all(result.rewards == 0.0)
+    assert torch.equal(result.reward_valid_mask, torch.tensor([False, False, False]))
diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index f3486de21e..4c7a21e2b9 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -17,6 +17,7 @@
 import tempfile
 from copy import deepcopy
 from dataclasses import asdict
+from types import SimpleNamespace
 
 import pytest
 import ray
@@ -30,6 +31,7 @@
 from nemo_rl.data.processors import nemo_gym_data_processor
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
+from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
 from nemo_rl.environments.games.sliding_puzzle import (
     SlidingPuzzleConfig,
     SlidingPuzzleEnv,
@@ -65,6 +67,49 @@
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 
 
+class _DummyTokenizer:
+    pad_token_id = 0
+
+    def __call__(
+        self,
+        text: str,
+        return_tensors: str = "pt",
+        add_special_tokens: bool = False,
+    ) -> SimpleNamespace:
+        del return_tensors, add_special_tokens
+        num_tokens = 0 if text == "" else max(1, len(text.split()))
+        input_ids = torch.arange(num_tokens, dtype=torch.long).unsqueeze(0)
+        return SimpleNamespace(input_ids=input_ids)
+
+
+@ray.remote(num_cpus=0)
+class InvalidRewardMaskEnv(EnvironmentInterface):
+    def __init__(self, rewards: list[float], reward_valid_mask: list[bool]):
+        self.rewards = rewards
+        self.reward_valid_mask = reward_valid_mask
+
+    def step(self, message_log_batch, metadata) -> EnvironmentReturn:
+        sample_indices = [sample_metadata["sample_idx"] for sample_metadata in metadata]
+        return EnvironmentReturn(
+            observations=[{"role": "environment", "content": "feedback"}]
+            * len(message_log_batch),
+            metadata=metadata,
+            next_stop_strings=[None] * len(message_log_batch),
+            rewards=torch.tensor(
+                [self.rewards[i] for i in sample_indices], dtype=torch.float32
+            ),
+            terminateds=torch.ones(len(message_log_batch), dtype=torch.bool),
+            answers=[None] * len(message_log_batch),
+            reward_valid_mask=torch.tensor(
+                [self.reward_valid_mask[i] for i in sample_indices],
+                dtype=torch.bool,
+            ),
+        )
+
+    def global_post_process_and_metrics(self, batch: BatchedDataDict) -> tuple:
+        return batch, {}
+
+
 class TestCalculateSingleMetric:
     """Unit tests for _calculate_single_metric function."""
 
@@ -94,11 +139,152 @@ def test_multiple_values_computes_stddev(self):
         assert result["test/median"] == 2.0
         assert abs(result["test/stddev"] - 1.0) < 1e-9  # stdev of [1,2,3] is 1.0
 
-    def test_two_identical_values_returns_zero_stddev(self):
-        """Test that stddev is 0 when all values are identical."""
-        result = _calculate_single_metric([5.0, 5.0], batch_size=2, key_name="test")
-        assert result["test/stddev"] == 0.0
+
+def test_run_multi_turn_rollout_propagates_reward_valid_mask(monkeypatch):
+    """Test sync rollouts preserve reward validity returned by the environment."""
+
+    def fake_generate_responses(
+        policy_generation,
+        generation_input_data,
+        batch,
+        tokenizer,
+        input_lengths,
+        include_logprobs=True,
+        greedy=False,
+    ):
+        del (
+            policy_generation,
+            generation_input_data,
+            tokenizer,
+            input_lengths,
+            include_logprobs,
+            greedy,
+        )
+        generated_ids = []
+        for i in range(len(batch["message_log"])):
+            token_ids = torch.tensor([11, 12], dtype=torch.long)
+            batch["message_log"][i].append(
+                {
+                    "role": "assistant",
+                    "content": "assistant response",
+                    "token_ids": token_ids,
+                }
+            )
+            generated_ids.append(token_ids)
+        return batch, generated_ids, {
+            "mean_generation_length": 2.0,
+            "total_generated_tokens": 2 * len(batch["message_log"]),
+        }
+
+    monkeypatch.setattr(
+        "nemo_rl.experience.rollouts.generate_responses", fake_generate_responses
+    )
+
+    tokenizer = _DummyTokenizer()
+    env = InvalidRewardMaskEnv.remote([0.0, 1.0], [False, True])
+    input_batch = BatchedDataDict(
+        {
+            "message_log": [
+                [{"role": "user", "content": "prompt 1", "token_ids": torch.tensor([1])}],
+                [{"role": "user", "content": "prompt 2", "token_ids": torch.tensor([2])}],
+            ],
+            "extra_env_info": [{"sample_idx": 0}, {"sample_idx": 1}],
+            "loss_multiplier": torch.tensor([1.0, 1.0]),
+            "idx": [0, 1],
+            "task_name": ["invalid_reward_env", "invalid_reward_env"],
+        }
+    )
+
+    try:
+        final_batch, rollout_metrics = run_multi_turn_rollout(
+            policy_generation=object(),
+            input_batch=input_batch,
+            tokenizer=tokenizer,
+            task_to_env={"invalid_reward_env": env},
+            max_seq_len=64,
+            max_rollout_turns=1,
+        )
+    finally:
+        ray.kill(env)
+
+    assert torch.equal(final_batch["reward_valid_mask"], torch.tensor([False, True]))
+    assert torch.allclose(final_batch["total_reward"], torch.tensor([0.0, 1.0]))
+    assert rollout_metrics["invalid_reward_count"] == 1
+    assert rollout_metrics["invalid_reward_rate"] == 0.5
+
+
+def test_run_async_multi_turn_rollout_propagates_reward_valid_mask(monkeypatch):
+    """Test async rollouts preserve reward validity returned by the environment."""
+
+    async def fake_async_generate_response_for_sample_turn(
+        policy_generation,
+        sample_message_log,
+        sample_stop_strings,
+        tokenizer,
+        max_seq_len,
+        greedy=False,
+    ):
+        del (
+            policy_generation,
+            sample_stop_strings,
+            tokenizer,
+            max_seq_len,
+            greedy,
+        )
+        token_ids = torch.tensor([21, 22], dtype=torch.long)
+        updated_message_log = deepcopy(sample_message_log)
+        updated_message_log.append(
+            {
+                "role": "assistant",
+                "content": "assistant response",
+                "token_ids": token_ids,
+            }
+        )
+        return updated_message_log, token_ids, torch.tensor(1), {}
+
+    monkeypatch.setattr(
+        "nemo_rl.experience.rollouts.async_generate_response_for_sample_turn",
+        fake_async_generate_response_for_sample_turn,
+    )
+
+    tokenizer = _DummyTokenizer()
+    env = InvalidRewardMaskEnv.remote([0.0, 1.0], [False, True])
+    input_batch = BatchedDataDict(
+        {
+            "message_log": [
+                [{"role": "user", "content": "prompt 1", "token_ids": torch.tensor([1])}],
+                [{"role": "user", "content": "prompt 2", "token_ids": torch.tensor([2])}],
+            ],
+            "extra_env_info": [{"sample_idx": 0}, {"sample_idx": 1}],
+            "loss_multiplier": torch.tensor([1.0, 1.0]),
+            "idx": [0, 1],
+            "task_name": ["invalid_reward_env", "invalid_reward_env"],
+        }
+    )
+
+    try:
+        final_batch, rollout_metrics = run_async_multi_turn_rollout(
+            policy_generation=object(),
+            input_batch=input_batch,
+            tokenizer=tokenizer,
+            task_to_env={"invalid_reward_env": env},
+            max_seq_len=64,
+            max_rollout_turns=1,
+        )
+    finally:
+        ray.kill(env)
+
+    assert torch.equal(final_batch["reward_valid_mask"], torch.tensor([False, True]))
+    assert torch.allclose(final_batch["total_reward"], torch.tensor([0.0, 1.0]))
+    assert rollout_metrics["invalid_reward_count"] == 1
+    assert rollout_metrics["invalid_reward_rate"] == 0.5
+
+
+def test_two_identical_values_returns_zero_stddev():
+    """Test that stddev is 0 when all values are identical."""
+    result = _calculate_single_metric([5.0, 5.0], batch_size=2, key_name="test")
+
+    assert result["test/stddev"] == 0.0
 
 
 @pytest.fixture(scope="function")