69 changes: 58 additions & 11 deletions nemo_rl/algorithms/advantage_estimator.py
@@ -32,6 +32,26 @@
 )


+def _get_sample_valid_mask(
+    sample_valid_mask: torch.Tensor | None,
+    reference: torch.Tensor,
+) -> torch.Tensor:
+    """Normalize optional sample validity info to a boolean tensor."""
+    if sample_valid_mask is None:
+        return torch.ones_like(reference, dtype=torch.bool)
+    normalized_mask = sample_valid_mask.to(
+        device=reference.device,
+        dtype=torch.bool,
+    ).reshape(-1)
+    if normalized_mask.shape[0] != reference.shape[0]:
+        raise ValueError(
+            "sample_valid_mask must have one element per sample; "
+            f"got {normalized_mask.shape[0]} values for "
+            f"{reference.shape[0]} samples"
+        )
+    return normalized_mask
+
+
 class GRPOAdvantageEstimator:
     """GRPO-style advantage estimator with leave-one-out baseline.

@@ -42,7 +62,9 @@ def __init__(self, estimator_config: dict, loss_config: dict):
         self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
         self.normalize_rewards = estimator_config["normalize_rewards"]

-    def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
+    def compute_advantage(
+        self, prompt_ids, rewards, mask, sample_valid_mask=None, **kwargs
+    ):
         """Compute GRPO advantages.

         Args:
@@ -55,13 +77,15 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
         Returns:
             Advantages tensor of shape [batch_size, seq_len].
         """
+        sample_valid_mask = _get_sample_valid_mask(sample_valid_mask, rewards)
         baseline, std = calculate_baseline_and_std_per_prompt(
             prompt_ids,
             rewards,
-            torch.ones_like(rewards),
+            sample_valid_mask,
             leave_one_out_baseline=self.use_leave_one_out_baseline,
         )
         advantages = (rewards - baseline).unsqueeze(-1)
+        advantages = advantages * sample_valid_mask.unsqueeze(-1).float()

         if self.normalize_rewards:
             # don't sharpen the ones with no variation
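Review note: the mask is threaded into the baseline call and also into the final multiply, because each fixes a different leak. A simplified stand-in for calculate_baseline_and_std_per_prompt (plain mean, no leave-one-out, one prompt group, hypothetical values) shows both effects:

import torch

rewards = torch.tensor([1.0, 0.0, 100.0])  # third sample flagged invalid
valid = torch.tensor([True, True, False])

baseline_all = rewards.mean()           # 33.67 -- skewed by the invalid sample
baseline_valid = rewards[valid].mean()  # 0.5   -- what passing the mask buys us

advantages = (rewards - baseline_valid) * valid.float()
# tensor([ 0.5, -0.5,  0.0]) -- invalid sample zeroed, baseline unpolluted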
@@ -90,6 +114,7 @@ def compute_advantage(
         rewards,
         mask,
         repeated_batch,
+        sample_valid_mask=None,
         **kwargs,
     ):
         """Compute GDPO advantages.
@@ -111,7 +136,10 @@ def compute_advantage(
                 f"This batch has {len(reward_component_keys)} component(s). "
                 "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
             )
-        valid = torch.ones_like(repeated_batch[reward_component_keys[0]])
+        valid = _get_sample_valid_mask(
+            sample_valid_mask,
+            repeated_batch[reward_component_keys[0]],
+        )
         leave_one_out = self.use_leave_one_out_baseline
         assert prompt_ids.shape[0] == valid.shape[0], (
             "prompt_ids must match reward batch size; "
@@ -137,12 +165,17 @@
             advantage_parts.append(adv_k)

         advantages = sum(advantage_parts)
-        # Normalize combined advantage to zero mean and unit std
-        adv_std = advantages.std()
-        if adv_std > 0:
-            advantages = (advantages - advantages.mean()) / adv_std
+        valid_advantages = advantages[valid]
+        if valid_advantages.numel() <= 1:
+            advantages = torch.zeros_like(advantages)
         else:
-            advantages = advantages - advantages.mean()
+            # Normalize combined advantage to zero mean and unit std using only valid samples.
+            adv_mean = valid_advantages.mean()
+            adv_std = valid_advantages.std()
+            advantages = advantages - adv_mean
+            if adv_std > 0:
+                advantages[valid] = advantages[valid] / adv_std
+        advantages[~valid] = 0.0

         return advantages.expand(mask.shape)
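Review note: the behavioral change in the GDPO normalization, shown standalone with hypothetical 1-D values (the real tensors may carry a trailing singleton dim):

import torch

advantages = torch.tensor([2.0, 4.0, -7.0])  # combined per-sample GDPO advantages
valid = torch.tensor([True, True, False])

valid_advantages = advantages[valid]          # tensor([2., 4.])
if valid_advantages.numel() <= 1:
    advantages = torch.zeros_like(advantages)
else:
    adv_mean = valid_advantages.mean()        # 3.0
    adv_std = valid_advantages.std()          # ~1.414 (unbiased)
    advantages = advantages - adv_mean
    if adv_std > 0:
        advantages[valid] = advantages[valid] / adv_std
advantages[~valid] = 0.0
# tensor([-0.7071, 0.7071, 0.0]): the invalid -7.0 no longer drags the
# mean/std, and its advantage is zeroed instead of being "normalized".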

@@ -166,6 +199,7 @@ def compute_advantage(
         prompt_ids,
         rewards,
         mask,
+        sample_valid_mask=None,
         logprobs_policy=None,
         logprobs_reference=None,
         **kwargs,
@@ -185,18 +219,25 @@ def compute_advantage(
         Returns:
             Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens.
         """
+        sample_valid_mask = _get_sample_valid_mask(sample_valid_mask, rewards)
+        sample_valid_token_mask = sample_valid_mask.unsqueeze(-1).to(
+            device=mask.device,
+            dtype=mask.dtype,
+        )
+        effective_mask = mask * sample_valid_token_mask
         # minus baseline
         if self.minus_baseline:
             mean, _ = calculate_baseline_and_std_per_prompt(
                 prompt_ids,
                 rewards,
-                torch.ones_like(rewards),
+                sample_valid_mask,
                 leave_one_out_baseline=False,
             )
             adv = rewards - mean
         else:
             adv = rewards

+        adv = adv * sample_valid_mask.float()
         adv = adv.unsqueeze(-1)
         adv = adv.expand(mask.shape)

@@ -212,11 +253,17 @@ def compute_advantage(
                 kl_type=self.kl_type,
             )
             adv = adv - self.kl_coef * kl
+            adv = adv * sample_valid_token_mask

         # global normalization across the batch
-        adv_mean = (adv * mask).sum() / mask.sum()
-        adv_var = ((adv - adv_mean).pow(2) * mask).sum() / mask.sum()
+        if effective_mask.sum() == 0:
+            return torch.zeros_like(adv)
+        adv_mean = (adv * effective_mask).sum() / effective_mask.sum()
+        adv_var = (
+            ((adv - adv_mean).pow(2) * effective_mask).sum() / effective_mask.sum()
+        )
         adv_rstd = adv_var.clamp(min=1e-8).rsqrt()
         adv = (adv - adv_mean) * adv_rstd
+        adv = adv * sample_valid_token_mask

         return adv
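Review note: how `effective_mask` changes the global normalization relative to the old `mask`-only statistics, with hypothetical [batch=2, seq=2] tensors:

import torch

adv = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # token-level validity
sample_valid = torch.tensor([True, False])       # sample 1 flagged invalid

sample_valid_token_mask = sample_valid.unsqueeze(-1).to(dtype=mask.dtype)
effective_mask = mask * sample_valid_token_mask  # [[1., 1.], [0., 0.]]

adv_mean = (adv * effective_mask).sum() / effective_mask.sum()                     # 1.5
adv_var = ((adv - adv_mean).pow(2) * effective_mask).sum() / effective_mask.sum()  # 0.25
adv = (adv - adv_mean) * adv_var.clamp(min=1e-8).rsqrt()
adv = adv * sample_valid_token_mask
# tensor([[-1., 1.], [0., 0.]]): the invalid row is zeroed and contributed
# nothing to the statistics, whereas the old mask-only version would have
# included its first token (3.0) in the mean/variance.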