Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 107 additions & 13 deletions atroposlib/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,41 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
# --- Reward normalization (opt-in; default "none" leaves scores untouched) ---
reward_normalization: str = Field(
default="none",
description="Reward normalization mode. 'none' = disabled (default), "
"'zscore' = z-score normalization, 'minmax' = min-max to [0,1]. "
"Uses Welford's online algorithm for running statistics.",
)
# Post-normalization clamp; 0 disables. Ignored while reward_normalization == "none".
reward_clip: float = Field(
default=5.0,
description="Maximum absolute reward value after normalization. "
"Only applies when reward_normalization is not 'none'. "
"Set to 0 to disable clipping.",
)
# Batches of raw (un-normalized) scores to pass through before the
# running statistics are trusted enough to start normalizing.
reward_normalization_warmup: int = Field(
default=10,
description="Number of scored batches to observe before activating "
"reward normalization. During warmup, raw scores are used.",
)
# --- Curriculum learning (opt-in; default "uniform" means no curriculum) ---
curriculum_strategy: str = Field(
default="uniform",
description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
"'easy_first' = oversample easy items early then anneal, "
"'competence_based' = sample at competence frontier. "
"See Platanios et al. 2019 for competence-based curriculum.",
)
# Difficulty histogram resolution used by the scheduler (must be >= 1).
curriculum_bins: int = Field(
default=5,
ge=1,
description="Number of difficulty bins for curriculum scheduling.",
)
# Softmax temperature over bins; must be strictly positive.
curriculum_temperature: float = Field(
default=1.0,
gt=0,
description="Temperature for curriculum bin sampling. Higher = more uniform, "
"lower = more concentrated on target difficulty.",
)


class BaseEnv(ABC):
Expand Down Expand Up @@ -262,6 +297,34 @@ def __init__(
self.max_token_len = -1
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
self.completion_lengths = []
# Initialize reward normalizer (opt-in via config)
if config.reward_normalization != "none":
from atroposlib.envs.reward_normalization import RewardNormalizer

self.reward_normalizer = RewardNormalizer(
mode=config.reward_normalization,
clip=config.reward_clip,
warmup=config.reward_normalization_warmup,
)
else:
self.reward_normalizer = None

# Initialize curriculum scheduler (opt-in via config)
if config.curriculum_strategy != "uniform":
from atroposlib.envs.curriculum import CurriculumScheduler

self.curriculum = CurriculumScheduler(
strategy=config.curriculum_strategy,
n_bins=config.curriculum_bins,
temperature=config.curriculum_temperature,
)
else:
self.curriculum = None

# Initialize API performance tracker for trainer-inference latency monitoring
from atroposlib.utils.api_perf import APIPerformanceTracker

self.api_perf_tracker = APIPerformanceTracker()
self.max_num_workers = config.max_num_workers
if self.max_num_workers == -1:
self.max_num_workers = config.max_num_workers_per_node * len(
Expand Down Expand Up @@ -656,8 +719,9 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""
if wandb_metrics is None:
wandb_metrics = dict()
server_wandb_metrics = {}
for i, server in enumerate(self.server.servers):
server_wandb_metrics = await server.wandb_metrics({}, f"server_{i}")
server_wandb_metrics.update(await server.wandb_metrics({}, f"server_{i}"))
if len(self.completion_lengths) > 0:
wandb_metrics["train/completion_lengths"] = sum(
self.completion_lengths
Expand All @@ -674,6 +738,14 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None):
wandb_metrics["train/completion_lengths_p95"] = (
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
).mean()
# Log reward normalization metrics if active
if self.reward_normalizer is not None:
wandb_metrics.update(self.reward_normalizer.metrics_dict())
# Log curriculum metrics if active
if self.curriculum is not None:
wandb_metrics.update(self.curriculum.metrics_dict())
# Log API performance metrics
wandb_metrics.update(self.api_perf_tracker.metrics_dict())
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = self.perf_stats(wandb_metrics)
self.rollouts_for_wandb = []
Expand Down Expand Up @@ -798,32 +870,44 @@ async def evaluate_log(
async def _send_scored_data_to_api(self, scored_data):
    """
    Send scored data to the API with retry logic for timeouts and server errors.

    Tracks latency and payload metrics via APIPerformanceTracker.

    Args:
        scored_data: Either a single scored-data dict or a list of such dicts.
            Each dict is tagged in place with this environment's ``env_id``
            before being POSTed.

    Raises:
        Exception: On HTTP 5xx responses, so the caller's retry logic can
            re-attempt the send. 4xx responses are logged and NOT retried.
    """
    # Tag every payload with the originating environment so the rollout
    # server can attribute the data; also count token groups for tracking.
    if isinstance(scored_data, list):
        for item in scored_data:
            item["env_id"] = getattr(self, "env_id", None)
        n_items = sum(len(item.get("tokens", [])) for item in scored_data)
    else:
        scored_data["env_id"] = getattr(self, "env_id", None)
        n_items = len(scored_data.get("tokens", []))

    # Lists go to the batched endpoint; single dicts to the scalar one.
    url = (
        f"{self.config.rollout_server_url}/scored_data_list"
        if isinstance(scored_data, list)
        else f"{self.config.rollout_server_url}/scored_data"
    )

    # Serialize to compute payload size for tracking. NOTE(review): this
    # serializes purely for measurement; the actual request body is built
    # by _post_json_with_compression, so the bytes counted here are the
    # pre-compression size.
    serialized = json.dumps(scored_data).encode("utf-8")
    payload_bytes = len(serialized)

    async with aiohttp.ClientSession() as session:
        # Wrap the whole request round-trip so latency, item count and
        # payload size are recorded even when the request fails.
        with self.api_perf_tracker.track_request(
            n_items=n_items,
            payload_bytes=payload_bytes,
        ):
            async with self._post_json_with_compression(
                session,
                url,
                scored_data,
            ) as resp:
                if resp.status >= 500:
                    # Transient server-side failure: raise so the caller's
                    # retry logic fires.
                    logging.debug(f"Server error: {resp.status}, retrying...")
                    raise Exception(f"Server error: {resp.status}")
                elif resp.status >= 400:
                    # Client error: retrying the same payload cannot succeed.
                    logging.error(f"Client error: {resp.status}, not retrying")
                    return
                logger.debug(await resp.text())

def _post_json_with_compression(
self,
Expand Down Expand Up @@ -892,6 +976,16 @@ async def handle_send_to_api(
logger.warning("Scores are the same in a group, skipping...")
continue

# Apply reward normalization if enabled (opt-in via config)
if self.reward_normalizer is not None:
group["scores"] = self.reward_normalizer.normalize(group["scores"])
# Re-check after normalization: if all scores collapsed, skip
if len(set(group["scores"])) == 1:
logger.debug(
"Scores collapsed to same value after normalization, skipping"
)
continue

group.setdefault("ref_logprobs", None)
group.setdefault("overrides", None)
group.setdefault("group_overrides", None)
Expand Down
Loading