diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 3d3b6c207..4120620ce 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -211,6 +211,41 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) + reward_normalization: str = Field( + default="none", + description="Reward normalization mode. 'none' = disabled (default), " + "'zscore' = z-score normalization, 'minmax' = min-max to [0,1]. " + "Uses Welford's online algorithm for running statistics.", + ) + reward_clip: float = Field( + default=5.0, + description="Maximum absolute reward value after normalization. " + "Only applies when reward_normalization is not 'none'. " + "Set to 0 to disable clipping.", + ) + reward_normalization_warmup: int = Field( + default=10, + description="Number of scored batches to observe before activating " + "reward normalization. During warmup, raw scores are used.", + ) + curriculum_strategy: str = Field( + default="uniform", + description="Curriculum learning strategy. 'uniform' = no curriculum (default), " + "'easy_first' = oversample easy items early then anneal, " + "'competence_based' = sample at competence frontier. " + "See Platanios et al. 2019 for competence-based curriculum.", + ) + curriculum_bins: int = Field( + default=5, + ge=1, + description="Number of difficulty bins for curriculum scheduling.", + ) + curriculum_temperature: float = Field( + default=1.0, + gt=0, + description="Temperature for curriculum bin sampling. 
Higher = more uniform, " + "lower = more concentrated on target difficulty.", + ) class BaseEnv(ABC): @@ -262,6 +297,34 @@ def __init__( self.max_token_len = -1 self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) self.completion_lengths = [] + # Initialize reward normalizer (opt-in via config) + if config.reward_normalization != "none": + from atroposlib.envs.reward_normalization import RewardNormalizer + + self.reward_normalizer = RewardNormalizer( + mode=config.reward_normalization, + clip=config.reward_clip, + warmup=config.reward_normalization_warmup, + ) + else: + self.reward_normalizer = None + + # Initialize curriculum scheduler (opt-in via config) + if config.curriculum_strategy != "uniform": + from atroposlib.envs.curriculum import CurriculumScheduler + + self.curriculum = CurriculumScheduler( + strategy=config.curriculum_strategy, + n_bins=config.curriculum_bins, + temperature=config.curriculum_temperature, + ) + else: + self.curriculum = None + + # Initialize API performance tracker for trainer-inference latency monitoring + from atroposlib.utils.api_perf import APIPerformanceTracker + + self.api_perf_tracker = APIPerformanceTracker() self.max_num_workers = config.max_num_workers if self.max_num_workers == -1: self.max_num_workers = config.max_num_workers_per_node * len( @@ -656,8 +719,9 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None): """ if wandb_metrics is None: wandb_metrics = dict() + server_wandb_metrics = {} for i, server in enumerate(self.server.servers): - server_wandb_metrics = await server.wandb_metrics({}, f"server_{i}") + server_wandb_metrics.update(await server.wandb_metrics({}, f"server_{i}")) if len(self.completion_lengths) > 0: wandb_metrics["train/completion_lengths"] = sum( self.completion_lengths @@ -674,6 +738,14 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None): wandb_metrics["train/completion_lengths_p95"] = ( np.array(self.completion_lengths) > (0.95 * self.max_token_len) ).mean() + # Log 
reward normalization metrics if active + if self.reward_normalizer is not None: + wandb_metrics.update(self.reward_normalizer.metrics_dict()) + # Log curriculum metrics if active + if self.curriculum is not None: + wandb_metrics.update(self.curriculum.metrics_dict()) + # Log API performance metrics + wandb_metrics.update(self.api_perf_tracker.metrics_dict()) wandb_metrics = await self.create_rollout_table(wandb_metrics) wandb_metrics = self.perf_stats(wandb_metrics) self.rollouts_for_wandb = [] @@ -798,32 +870,44 @@ async def evaluate_log( async def _send_scored_data_to_api(self, scored_data): """ Send scored data to the API with retry logic for timeouts and server errors. + Tracks latency and payload metrics via APIPerformanceTracker. """ # Add env_id to the data if isinstance(scored_data, list): for item in scored_data: item["env_id"] = getattr(self, "env_id", None) + n_items = sum(len(item.get("tokens", [])) for item in scored_data) else: scored_data["env_id"] = getattr(self, "env_id", None) + n_items = len(scored_data.get("tokens", [])) url = ( f"{self.config.rollout_server_url}/scored_data_list" if isinstance(scored_data, list) else f"{self.config.rollout_server_url}/scored_data" ) + + # Serialize to compute payload size for tracking + serialized = json.dumps(scored_data).encode("utf-8") + payload_bytes = len(serialized) + async with aiohttp.ClientSession() as session: - async with self._post_json_with_compression( - session, - url, - scored_data, - ) as resp: - if resp.status >= 500: - logging.debug(f"Server error: {resp.status}, retrying...") - raise Exception(f"Server error: {resp.status}") - elif resp.status >= 400: - logging.error(f"Client error: {resp.status}, not retrying") - return - logger.debug(await resp.text()) + with self.api_perf_tracker.track_request( + n_items=n_items, + payload_bytes=payload_bytes, + ): + async with self._post_json_with_compression( + session, + url, + scored_data, + ) as resp: + if resp.status >= 500: + logging.debug(f"Server 
error: {resp.status}, retrying...") + raise Exception(f"Server error: {resp.status}") + elif resp.status >= 400: + logging.error(f"Client error: {resp.status}, not retrying") + return + logger.debug(await resp.text()) def _post_json_with_compression( self, @@ -892,6 +976,16 @@ async def handle_send_to_api( logger.warning("Scores are the same in a group, skipping...") continue + # Apply reward normalization if enabled (opt-in via config) + if self.reward_normalizer is not None: + group["scores"] = self.reward_normalizer.normalize(group["scores"]) + # Re-check after normalization: if all scores collapsed, skip + if len(set(group["scores"])) == 1: + logger.debug( + "Scores collapsed to same value after normalization, skipping" + ) + continue + group.setdefault("ref_logprobs", None) group.setdefault("overrides", None) group.setdefault("group_overrides", None) diff --git a/atroposlib/envs/curriculum.py b/atroposlib/envs/curriculum.py new file mode 100644 index 000000000..2f8fd53a1 --- /dev/null +++ b/atroposlib/envs/curriculum.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +""" +Curriculum learning scheduler for sample-efficient RL training. + +Implements automatic difficulty-based sampling for environments, tracking +per-item difficulty from reward signals and adjusting sampling probabilities +to focus training on appropriately challenging examples. + +Strategies: +- uniform: No curriculum (baseline, default) +- easy_first: Oversample easy items early, anneal to uniform +- competence_based: Sample items at the competence frontier (reward ~ 0.5), + following Platanios et al. 
2019 (https://arxiv.org/abs/1904.03746) + +Usage: + scheduler = CurriculumScheduler( + strategy="competence_based", + n_bins=5, + temperature=1.0, + ) + + # After scoring an item + scheduler.update("item_key_123", reward_score=0.7) + + # When selecting next item + target_bin = scheduler.sample_bin(current_step=50, total_steps=1000) +""" + +import logging +import math +import random +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +class CurriculumStrategy(str, Enum): + """Available curriculum learning strategies.""" + + UNIFORM = "uniform" + EASY_FIRST = "easy_first" + COMPETENCE_BASED = "competence_based" + + +class CurriculumScheduler: + """ + Curriculum learning scheduler that tracks item difficulty and provides + difficulty-aware sampling. + + Maintains an exponential moving average (EMA) of reward scores per item + to estimate difficulty. Items are binned by difficulty quantile, and the + sampling strategy determines which bins are preferred at each stage of + training. + + Args: + strategy: Sampling strategy. One of "uniform", "easy_first", + "competence_based". + n_bins: Number of difficulty bins. Default: 5. + temperature: Controls sampling sharpness. Higher = more uniform, + lower = more concentrated on target bin. Default: 1.0. + ema_alpha: EMA smoothing factor for difficulty scores. Higher values + give more weight to recent rewards. Default: 0.3. + competence_threshold: For competence_based strategy, the target + reward level considered "at frontier". Default: 0.5. + """ + + def __init__( + self, + strategy: str = "uniform", + n_bins: int = 5, + temperature: float = 1.0, + ema_alpha: float = 0.3, + competence_threshold: float = 0.5, + ): + # Validate strategy + try: + self._strategy = CurriculumStrategy(strategy) + except ValueError: + valid = [s.value for s in CurriculumStrategy] + raise ValueError( + f"Invalid curriculum strategy '{strategy}'. 
Must be one of: {valid}" + ) + + if n_bins < 1: + raise ValueError(f"n_bins must be >= 1, got {n_bins}") + + self.n_bins = n_bins + self.temperature = max(0.01, temperature) + self.ema_alpha = max(0.0, min(1.0, ema_alpha)) + self.competence_threshold = competence_threshold + + # Per-item difficulty tracking: key -> (ema_score, count) + self._item_scores: Dict[str, Tuple[float, int]] = {} + + # Bin boundaries (recomputed periodically) + self._bin_boundaries: List[float] = [] + self._last_rebin_count: int = 0 + self._rebin_interval: int = 50 # Recompute bins every N updates + + @property + def strategy(self) -> str: + """Current strategy name.""" + return self._strategy.value + + @property + def n_items_tracked(self) -> int: + """Number of unique items being tracked.""" + return len(self._item_scores) + + def update(self, item_key: str, score: float) -> None: + """ + Update difficulty estimate for an item based on its reward score. + + Uses exponential moving average so recent performance has more + influence than historical. + + Args: + item_key: Unique identifier for the item (e.g., dataset index). + score: Reward score achieved on this item. Higher = easier. + """ + if item_key in self._item_scores: + old_ema, count = self._item_scores[item_key] + new_ema = self.ema_alpha * score + (1 - self.ema_alpha) * old_ema + self._item_scores[item_key] = (new_ema, count + 1) + else: + self._item_scores[item_key] = (score, 1) + + # Periodically recompute bin boundaries + total_updates = sum(c for _, c in self._item_scores.values()) + if total_updates - self._last_rebin_count >= self._rebin_interval: + self._recompute_bins() + self._last_rebin_count = total_updates + + def update_batch(self, item_key: str, scores: List[float]) -> None: + """ + Update difficulty estimate with multiple scores (e.g., from group_size). + + Args: + item_key: Unique identifier for the item. + scores: List of reward scores from the group rollout. 
+ """ + if not scores: + return + avg_score = sum(scores) / len(scores) + self.update(item_key, avg_score) + + def get_item_difficulty(self, item_key: str) -> Optional[float]: + """ + Get the current difficulty estimate for an item. + + Returns: + EMA reward score (higher = easier), or None if item not tracked. + """ + if item_key not in self._item_scores: + return None + return self._item_scores[item_key][0] + + def get_item_bin(self, item_key: str) -> int: + """ + Get the difficulty bin for an item. + + Args: + item_key: Unique identifier for the item. + + Returns: + Bin index (0 = easiest, n_bins-1 = hardest). + Returns middle bin if item is not tracked. + """ + difficulty = self.get_item_difficulty(item_key) + if difficulty is None: + return self.n_bins // 2 # Default to middle bin + + if not self._bin_boundaries: + self._recompute_bins() + + # Bin assignment: higher score = lower bin index (easier) + # We invert so bin 0 = easiest (highest reward) + for i, boundary in enumerate(self._bin_boundaries): + if difficulty >= boundary: + return i + return self.n_bins - 1 + + def sample_bin(self, current_step: int = 0, total_steps: int = 1000) -> int: + """ + Sample a target difficulty bin based on the curriculum strategy. + + Args: + current_step: Current training step (for annealing strategies). + total_steps: Total training steps planned. + + Returns: + Target bin index to sample from (0 = easiest, n_bins-1 = hardest). 
+ """ + if self._strategy == CurriculumStrategy.UNIFORM: + return random.randint(0, self.n_bins - 1) + + # Compute bin probabilities + probs = self._compute_bin_probabilities(current_step, total_steps) + + # Temperature-scaled sampling + if self.temperature != 1.0: + log_probs = [math.log(max(p, 1e-10)) / self.temperature for p in probs] + max_lp = max(log_probs) + exp_probs = [math.exp(lp - max_lp) for lp in log_probs] + total = sum(exp_probs) + probs = [p / total for p in exp_probs] + + # Weighted random choice + return random.choices(range(self.n_bins), weights=probs, k=1)[0] + + def _compute_bin_probabilities( + self, current_step: int, total_steps: int + ) -> List[float]: + """Compute sampling probabilities for each bin.""" + progress = min(1.0, max(0.0, current_step / max(1, total_steps))) + + if self._strategy == CurriculumStrategy.EASY_FIRST: + return self._easy_first_probs(progress) + elif self._strategy == CurriculumStrategy.COMPETENCE_BASED: + return self._competence_based_probs(progress) + else: + # Uniform fallback + return [1.0 / self.n_bins] * self.n_bins + + def _easy_first_probs(self, progress: float) -> List[float]: + """ + Easy-first: linearly anneal from easy-biased to uniform. + + At progress=0: strongly prefer easy items (bin 0). + At progress=1: uniform sampling across all bins. + """ + probs = [] + for i in range(self.n_bins): + # Base: uniform + uniform_prob = 1.0 / self.n_bins + # Bias: exponential decay favoring low bins (easy) + easy_bias = math.exp(-2.0 * i / max(1, self.n_bins - 1)) + # Anneal from biased to uniform + prob = (1.0 - progress) * easy_bias + progress * uniform_prob + probs.append(prob) + + # Normalize + total = sum(probs) + return [p / total for p in probs] + + def _competence_based_probs(self, progress: float) -> List[float]: + """ + Competence-based: sample items near the competence frontier. + + The frontier moves from easy to hard as training progresses. 
+ Items where expected reward ~ competence_threshold are preferred. + """ + # Competence level increases with training progress + # Maps to which bin is at the frontier + frontier_bin = progress * (self.n_bins - 1) + + probs = [] + for i in range(self.n_bins): + # Gaussian-like probability centered on frontier bin + distance = abs(i - frontier_bin) + prob = math.exp(-0.5 * (distance**2)) + probs.append(prob) + + total = sum(probs) + return [p / total for p in probs] + + def _recompute_bins(self) -> None: + """Recompute bin boundaries based on current difficulty quantiles.""" + if not self._item_scores: + self._bin_boundaries = [] + return + + # Sort scores descending (highest reward = easiest = bin 0) + scores = sorted([ema for ema, _ in self._item_scores.values()], reverse=True) + + if len(scores) < self.n_bins: + # Not enough items to properly bin, use equal spacing + min_s = min(scores) + max_s = max(scores) + if max_s == min_s: + self._bin_boundaries = [min_s] * self.n_bins + else: + step = (max_s - min_s) / self.n_bins + self._bin_boundaries = [max_s - i * step for i in range(self.n_bins)] + return + + # Quantile-based boundaries + boundaries = [] + for i in range(self.n_bins): + idx = int(i * len(scores) / self.n_bins) + idx = min(idx, len(scores) - 1) + boundaries.append(scores[idx]) + self._bin_boundaries = boundaries + + def metrics_dict(self) -> Dict[str, float]: + """ + Return curriculum stats for WandB logging. + + Returns: + Dictionary with keys suitable for wandb.log(). 
+ """ + if not self._item_scores: + return { + "curriculum/items_tracked": 0, + "curriculum/strategy": 0, # Can't log strings to wandb + } + + scores = [ema for ema, _ in self._item_scores.values()] + counts = [c for _, c in self._item_scores.values()] + + metrics = { + "curriculum/items_tracked": float(len(scores)), + "curriculum/mean_difficulty": sum(scores) / len(scores), + "curriculum/min_difficulty": min(scores), + "curriculum/max_difficulty": max(scores), + "curriculum/total_updates": float(sum(counts)), + } + + # Bin distribution + if self._bin_boundaries: + bin_counts = [0] * self.n_bins + for key in self._item_scores: + bin_idx = self.get_item_bin(key) + bin_counts[bin_idx] += 1 + for i, count in enumerate(bin_counts): + metrics[f"curriculum/bin_{i}_count"] = float(count) + + return metrics + + def state_dict(self) -> Dict[str, Any]: + """Serialize state for checkpointing.""" + return { + "strategy": self._strategy.value, + "n_bins": self.n_bins, + "temperature": self.temperature, + "ema_alpha": self.ema_alpha, + "competence_threshold": self.competence_threshold, + "item_scores": dict(self._item_scores), + "bin_boundaries": self._bin_boundaries, + "last_rebin_count": self._last_rebin_count, + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Restore state from checkpoint.""" + self._strategy = CurriculumStrategy(state["strategy"]) + self.n_bins = state["n_bins"] + self.temperature = state["temperature"] + self.ema_alpha = state["ema_alpha"] + self.competence_threshold = state["competence_threshold"] + self._item_scores = {k: tuple(v) for k, v in state["item_scores"].items()} + self._bin_boundaries = state["bin_boundaries"] + self._last_rebin_count = state["last_rebin_count"] diff --git a/atroposlib/envs/reward_fns/__init__.py b/atroposlib/envs/reward_fns/__init__.py index 411d61490..9ea6d0b04 100644 --- a/atroposlib/envs/reward_fns/__init__.py +++ b/atroposlib/envs/reward_fns/__init__.py @@ -24,7 +24,8 @@ def compute(self, completions, 
**kwargs): """ from .combined_reward import CombinedReward +from .ensemble_reward import EnsembleReward from .registry import registry from .reward_function import RewardFunction -__all__ = ["RewardFunction", "registry", "CombinedReward"] +__all__ = ["RewardFunction", "registry", "CombinedReward", "EnsembleReward"] diff --git a/atroposlib/envs/reward_fns/ensemble_reward.py b/atroposlib/envs/reward_fns/ensemble_reward.py new file mode 100644 index 000000000..811788880 --- /dev/null +++ b/atroposlib/envs/reward_fns/ensemble_reward.py @@ -0,0 +1,314 @@ +""" +Ensemble reward function with robust aggregation and inter-rater reliability. + +Extends the CombinedReward pattern with: +- Multiple aggregation strategies (mean, median, min, majority_vote) +- Inter-rater reliability metrics (Krippendorff's alpha) +- Disagreement tracking for reward hacking detection + +Usage: + reward_fn = registry.create("ensemble", rewards=["accuracy", "format"], strategy="median") + scores = reward_fn(completions, **kwargs) + + # Access reliability metrics + alpha = reward_fn.last_reliability_alpha +""" + +import logging +import warnings +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +def _krippendorff_alpha(ratings_matrix: np.ndarray) -> float: + """ + Compute Krippendorff's alpha for inter-rater reliability. + + Uses the interval/ratio metric (squared differences). + + Args: + ratings_matrix: Shape (n_raters, n_items). NaN values indicate + missing ratings and are excluded from computation. + + Returns: + Alpha value in [-1, 1]. 1 = perfect agreement, 0 = chance agreement, + negative = systematic disagreement. 
+ """ + n_raters, n_items = ratings_matrix.shape + + if n_raters < 2 or n_items < 2: + return float("nan") + + # Build coincidence matrix approach using pairwise disagreements + # For each item, compute observed disagreement across all rater pairs + observed_disagreement = 0.0 + total_pairs = 0 + + for item_idx in range(n_items): + values = ratings_matrix[:, item_idx] + valid = values[~np.isnan(values)] + n_valid = len(valid) + if n_valid < 2: + continue + + # Sum of squared differences for all pairs within this item + for i in range(n_valid): + for j in range(i + 1, n_valid): + observed_disagreement += (valid[i] - valid[j]) ** 2 + total_pairs += 1 + + if total_pairs == 0: + return float("nan") + + observed_disagreement /= total_pairs + + # Expected disagreement: pairwise differences across ALL values + all_valid = ratings_matrix[~np.isnan(ratings_matrix)] + n_all = len(all_valid) + if n_all < 2: + return float("nan") + + expected_disagreement = 0.0 + expected_pairs = 0 + for i in range(n_all): + for j in range(i + 1, n_all): + expected_disagreement += (all_valid[i] - all_valid[j]) ** 2 + expected_pairs += 1 + + if expected_pairs == 0: + return float("nan") + + expected_disagreement /= expected_pairs + + if expected_disagreement == 0.0: + # All raters gave identical scores -- perfect agreement + return 1.0 + + alpha = 1.0 - (observed_disagreement / expected_disagreement) + return float(alpha) + + +@registry.register +class EnsembleReward(RewardFunction): + """ + Ensemble reward function that aggregates multiple reward functions + with robust strategies and inter-rater reliability tracking. 
+ + Compared to CombinedReward, this adds: + - Median and min (conservative) aggregation for robustness + - Majority vote for binary reward environments + - Krippendorff's alpha inter-rater reliability metric + - Per-item disagreement tracking for reward hacking detection + + Strategies: + - "mean": Weighted average (same as CombinedReward) + - "median": Median across reward functions (robust to outliers) + - "min": Conservative -- use the minimum score (prevents reward hacking) + - "majority_vote": For binary rewards -- majority wins (ties -> positive) + """ + + def __init__( + self, + rewards: List[Union[str, Dict]], + strategy: str = "mean", + weight: float = 1.0, + track_disagreement: bool = True, + **kwargs, + ): + """ + Initialize the ensemble reward function. + + Args: + rewards: List of reward function names or config dicts. + Resolved via RewardRegistry. + strategy: Aggregation strategy. One of: "mean", "median", + "min", "majority_vote". + weight: Weight for this ensemble when used inside another + CombinedReward. + track_disagreement: If True, track per-item reward variance + for disagreement analysis. + **kwargs: Additional parameters passed to RewardFunction. + """ + super().__init__(weight=weight, **kwargs) + + valid_strategies = {"mean", "median", "min", "majority_vote"} + if strategy not in valid_strategies: + raise ValueError( + f"Invalid strategy '{strategy}'. Must be one of: {valid_strategies}" + ) + + self.strategy = strategy + self.track_disagreement = track_disagreement + self.reward_functions: List[RewardFunction] = [] + + # Initialize sub-reward functions via registry + for reward_config in rewards: + self.reward_functions.append(registry.create(reward_config)) + + if len(self.reward_functions) < 2: + warnings.warn( + "EnsembleReward initialized with fewer than 2 reward functions. 
" + "Inter-rater reliability metrics will not be meaningful.", + stacklevel=2, + ) + + # State for reliability tracking + self.last_reliability_alpha: float = float("nan") + self.last_disagreement_scores: Optional[List[float]] = None + self._all_sub_rewards: Optional[List[List[float]]] = None + + @property + def name(self) -> str: + sub_names = ",".join(r.name for r in self.reward_functions) + return f"ensemble_{self.strategy}({sub_names})" + + def set_wandb_logger(self, wandb_logger): + """Propagate WandB logger to all sub-reward functions.""" + super().set_wandb_logger(wandb_logger) + for reward_fn in self.reward_functions: + reward_fn.set_wandb_logger(wandb_logger) + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Compute ensemble reward scores. + + Calls all sub-reward functions, aggregates by strategy, + and computes reliability metrics. + + Args: + completions: List of completions to evaluate. + **kwargs: Additional context passed to sub-rewards. + + Returns: + Aggregated reward scores, one per completion. + """ + if not completions: + return [] + + n_completions = len(completions) + + # Collect all sub-reward scores + all_rewards: List[List[float]] = [] + for reward_fn in self.reward_functions: + try: + scores = reward_fn.compute(completions, **kwargs) + if len(scores) != n_completions: + logger.warning( + "Reward function %s returned %d scores for %d completions. 
" + "Padding/truncating.", + reward_fn.name, + len(scores), + n_completions, + ) + # Pad or truncate + if len(scores) < n_completions: + scores = scores + [0.0] * (n_completions - len(scores)) + else: + scores = scores[:n_completions] + all_rewards.append(scores) + except Exception as e: + logger.error("Error in reward function %s: %s", reward_fn.name, e) + all_rewards.append([0.0] * n_completions) + + self._all_sub_rewards = all_rewards + + if not all_rewards: + return [0.0] * n_completions + + # Convert to numpy for efficient aggregation + # Shape: (n_reward_fns, n_completions) + reward_matrix = np.array(all_rewards, dtype=np.float64) + + # Aggregate by strategy + if self.strategy == "mean": + aggregated = np.mean(reward_matrix, axis=0) + elif self.strategy == "median": + aggregated = np.median(reward_matrix, axis=0) + elif self.strategy == "min": + aggregated = np.min(reward_matrix, axis=0) + elif self.strategy == "majority_vote": + # Treat positive as vote for 1, non-positive as vote for 0 + votes = (reward_matrix > 0).astype(np.float64) + vote_fractions = np.mean(votes, axis=0) + # Majority wins; ties (0.5) go to positive + aggregated = np.where(vote_fractions >= 0.5, 1.0, 0.0) + else: + # Should not reach here due to __init__ validation + aggregated = np.mean(reward_matrix, axis=0) + + # Compute reliability metrics + self._compute_reliability_metrics(reward_matrix) + + # Track per-item disagreement + if self.track_disagreement: + self.last_disagreement_scores = np.var(reward_matrix, axis=0).tolist() + + return aggregated.tolist() + + def _compute_reliability_metrics(self, reward_matrix: np.ndarray): + """ + Compute and store inter-rater reliability metrics. 
+ + Args: + reward_matrix: Shape (n_raters, n_items) + """ + n_raters, n_items = reward_matrix.shape + + if n_raters < 2 or n_items < 2: + self.last_reliability_alpha = float("nan") + return + + self.last_reliability_alpha = _krippendorff_alpha(reward_matrix) + + def reliability_metrics(self) -> Dict[str, float]: + """ + Return the latest inter-rater reliability metrics. + + Returns: + Dictionary with reliability statistics: + - alpha: Krippendorff's alpha + - mean_disagreement: Average per-item variance across raters + - max_disagreement: Maximum per-item variance (worst agreement) + """ + metrics = { + "alpha": self.last_reliability_alpha, + } + + if self.last_disagreement_scores is not None: + scores = self.last_disagreement_scores + metrics["mean_disagreement"] = sum(scores) / len(scores) if scores else 0.0 + metrics["max_disagreement"] = max(scores) if scores else 0.0 + + return metrics + + def log_metrics(self, raw_rewards: List[float], weighted_rewards: List[float]): + """Log ensemble-specific metrics alongside standard reward metrics.""" + super().log_metrics(raw_rewards, weighted_rewards) + + if not self.wandb_logger: + return + + reliability = self.reliability_metrics() + wandb_metrics = {} + + if not np.isnan(reliability.get("alpha", float("nan"))): + wandb_metrics[f"reward/{self.name}/reliability_alpha"] = reliability[ + "alpha" + ] + + if "mean_disagreement" in reliability: + wandb_metrics[f"reward/{self.name}/mean_disagreement"] = reliability[ + "mean_disagreement" + ] + wandb_metrics[f"reward/{self.name}/max_disagreement"] = reliability[ + "max_disagreement" + ] + + if wandb_metrics: + self.wandb_logger.log(wandb_metrics) diff --git a/atroposlib/envs/reward_normalization.py b/atroposlib/envs/reward_normalization.py new file mode 100644 index 000000000..5531200e7 --- /dev/null +++ b/atroposlib/envs/reward_normalization.py @@ -0,0 +1,267 @@ +""" +Online reward normalization for multi-environment RL training stability. 
+ +Implements Welford's online algorithm for running mean/variance computation, +enabling z-score and min-max normalization of reward signals without needing +to store all historical values. + +This is critical for multi-environment training where different environments +produce rewards on different scales (e.g., GSM8K gives {-1, 1} while +tool-use environments give continuous [0, 1] scores). + +Usage: + normalizer = RewardNormalizer(mode="zscore", clip=5.0) + + # During training loop + scores = [0.5, -0.3, 0.8, 1.0] + normalized = normalizer.normalize(scores) + + # Checkpointing + state = normalizer.state_dict() + normalizer.load_state_dict(state) +""" + +import logging +import math +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class WelfordAccumulator: + """ + Welford's online algorithm for computing running mean and variance. + + Numerically stable single-pass algorithm that avoids catastrophic + cancellation. Maintains count, mean, and M2 (sum of squared deviations) + to compute variance on demand. + + Reference: Welford, B. P. (1962). "Note on a method for calculating + corrected sums of squares and products". Technometrics. 4(3): 419-420. 
+ """ + + def __init__(self): + self.count: int = 0 + self.mean: float = 0.0 + self._m2: float = 0.0 + self._min: float = float("inf") + self._max: float = float("-inf") + + def update(self, value: float) -> None: + """Update running statistics with a new value.""" + self.count += 1 + delta = value - self.mean + self.mean += delta / self.count + delta2 = value - self.mean + self._m2 += delta * delta2 + self._min = min(self._min, value) + self._max = max(self._max, value) + + def update_batch(self, values: List[float]) -> None: + """Update running statistics with a batch of values.""" + for v in values: + self.update(v) + + @property + def variance(self) -> float: + """Population variance of all observed values.""" + if self.count < 2: + return 0.0 + return self._m2 / self.count + + @property + def std(self) -> float: + """Population standard deviation of all observed values.""" + return math.sqrt(self.variance) + + @property + def min_val(self) -> float: + """Minimum observed value.""" + return self._min if self.count > 0 else 0.0 + + @property + def max_val(self) -> float: + """Maximum observed value.""" + return self._max if self.count > 0 else 0.0 + + def state_dict(self) -> Dict[str, Any]: + """Serialize state for checkpointing.""" + return { + "count": self.count, + "mean": self.mean, + "m2": self._m2, + "min": self._min, + "max": self._max, + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Restore state from checkpoint.""" + self.count = state["count"] + self.mean = state["mean"] + self._m2 = state["m2"] + self._min = state["min"] + self._max = state["max"] + + +class RewardNormalizer: + """ + Reward normalization for stable multi-environment RL training. + + Supports two normalization modes: + - "zscore": Standardize to zero mean, unit variance using running stats + - "minmax": Scale to [0, 1] range using observed min/max + + Both modes use Welford's online algorithm so no historical data storage + is required. 
Optional reward clipping prevents extreme values from + destabilizing training. + + Args: + mode: Normalization mode. One of "zscore", "minmax", or "none". + clip: Maximum absolute value after normalization. Set to 0 or None + to disable clipping. Default: 5.0. + warmup: Minimum number of samples before normalization activates. + During warmup, raw scores are returned (optionally clipped). + Default: 10. + eps: Small constant for numerical stability in division. Default: 1e-8. + """ + + VALID_MODES = {"zscore", "minmax", "none"} + + def __init__( + self, + mode: str = "zscore", + clip: Optional[float] = 5.0, + warmup: int = 10, + eps: float = 1e-8, + ): + if mode not in self.VALID_MODES: + raise ValueError( + f"Invalid normalization mode '{mode}'. " + f"Must be one of: {self.VALID_MODES}" + ) + + self.mode = mode + self.clip = clip if clip and clip > 0 else None + self.warmup = max(0, warmup) + self.eps = eps + self._accumulator = WelfordAccumulator() + + @property + def count(self) -> int: + """Number of samples observed.""" + return self._accumulator.count + + @property + def mean(self) -> float: + """Running mean of observed values.""" + return self._accumulator.mean + + @property + def std(self) -> float: + """Running standard deviation of observed values.""" + return self._accumulator.std + + @property + def is_warmed_up(self) -> bool: + """Whether enough samples have been observed for normalization.""" + return self._accumulator.count >= self.warmup + + def normalize(self, scores: List[float]) -> List[float]: + """ + Normalize a batch of reward scores. + + Updates running statistics with the new scores, then applies + normalization. During warmup, raw scores are returned (with + optional clipping). + + Args: + scores: Raw reward scores to normalize. + + Returns: + Normalized (and optionally clipped) scores. 
+ """ + if not scores: + return [] + + if self.mode == "none": + return list(scores) + + # Update running statistics + self._accumulator.update_batch(scores) + + # During warmup, return raw scores (optionally clipped) + if not self.is_warmed_up: + logger.debug( + "Reward normalizer warmup: %d/%d samples", + self._accumulator.count, + self.warmup, + ) + return self._clip(list(scores)) + + # Apply normalization + if self.mode == "zscore": + normalized = self._zscore(scores) + elif self.mode == "minmax": + normalized = self._minmax(scores) + else: + normalized = list(scores) + + return self._clip(normalized) + + def _zscore(self, scores: List[float]) -> List[float]: + """Z-score normalize: (x - mean) / std.""" + mean = self._accumulator.mean + std = self._accumulator.std + if std < self.eps: + # All values nearly identical -- return zeros + return [0.0] * len(scores) + return [(s - mean) / (std + self.eps) for s in scores] + + def _minmax(self, scores: List[float]) -> List[float]: + """Min-max normalize to [0, 1] range.""" + min_val = self._accumulator.min_val + max_val = self._accumulator.max_val + range_val = max_val - min_val + if range_val < self.eps: + return [0.5] * len(scores) + return [(s - min_val) / (range_val + self.eps) for s in scores] + + def _clip(self, scores: List[float]) -> List[float]: + """Clip scores to [-clip, clip] range.""" + if self.clip is None: + return scores + return [max(-self.clip, min(self.clip, s)) for s in scores] + + def metrics_dict(self) -> Dict[str, float]: + """ + Return current normalization statistics for WandB logging. + + Returns: + Dictionary with keys suitable for wandb.log(). 
+ """ + metrics = { + "reward_norm/count": float(self._accumulator.count), + "reward_norm/mean": self._accumulator.mean, + "reward_norm/std": self._accumulator.std, + "reward_norm/min": self._accumulator.min_val, + "reward_norm/max": self._accumulator.max_val, + } + return metrics + + def state_dict(self) -> Dict[str, Any]: + """Serialize full state for checkpointing.""" + return { + "mode": self.mode, + "clip": self.clip, + "warmup": self.warmup, + "eps": self.eps, + "accumulator": self._accumulator.state_dict(), + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Restore state from checkpoint.""" + self.mode = state["mode"] + self.clip = state["clip"] + self.warmup = state["warmup"] + self.eps = state["eps"] + self._accumulator.load_state_dict(state["accumulator"]) diff --git a/atroposlib/envs/server_handling/server_harness.py b/atroposlib/envs/server_handling/server_harness.py index 90a087122..2ea12e303 100644 --- a/atroposlib/envs/server_handling/server_harness.py +++ b/atroposlib/envs/server_handling/server_harness.py @@ -166,6 +166,10 @@ async def tokens_and_logprobs_completion( except KeyError as e: raise KeyError(f"KeyError: {e} for prompt:\n{prompt}") + async def wandb_metrics(self, metrics_dict: dict, server_name: str) -> dict: + """Mock implementation of wandb_metrics.""" + return metrics_dict + if __name__ == "__main__": diff --git a/atroposlib/tests/test_api_perf.py b/atroposlib/tests/test_api_perf.py new file mode 100644 index 000000000..1b3c229be --- /dev/null +++ b/atroposlib/tests/test_api_perf.py @@ -0,0 +1,100 @@ +""" +Tests for API performance tracking utilities. 
+ +Tests cover: +- Request tracking via context manager +- Latency percentile computation (p50, p95, p99) +- Throughput calculation (items/sec, requests/sec) +- Compression ratio tracking +- Rolling window behavior +- Multi-threaded/Parallel safety (simulated via multiple records) +- WandB metrics dictionary formatting +""" + +import unittest +from unittest.mock import patch + +from atroposlib.utils.api_perf import APIPerformanceTracker + + +class TestAPIPerformanceTracker(unittest.TestCase): + def setUp(self): + self.tracker = APIPerformanceTracker(window_size=10) + + def test_track_request_context_manager(self): + with patch("time.monotonic", side_effect=[0, 0.05]): # 50ms latency + with self.tracker.track_request(n_items=2, payload_bytes=1000): + pass + + stats = self.tracker.latency_stats() + self.assertEqual(self.tracker.n_records, 1) + self.assertAlmostEqual(stats["mean_ms"], 50.0) + + def test_latency_percentiles(self): + # Record 10 requests with specific latencies: 10, 20, ..., 100ms + for i in range(1, 11): + self.tracker.record_request(latency_ms=float(i * 10), n_items=1) + + stats = self.tracker.latency_stats() + # p50 of [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] is 55.0 + self.assertAlmostEqual(stats["p50_ms"], 55.0) + # p95 is between 90 and 100 + self.assertGreater(stats["p95_ms"], 90.0) + self.assertLessEqual(stats["p95_ms"], 100.0) + + def test_rolling_window(self): + tracker = APIPerformanceTracker(window_size=3) + for i in range(5): + tracker.record_request(latency_ms=float(i), n_items=1) + + self.assertEqual(tracker.n_records, 3) + stats = tracker.latency_stats() + # Should only have records for 2, 3, 4 + self.assertAlmostEqual(stats["min_ms"], 2.0) + self.assertAlmostEqual(stats["max_ms"], 4.0) + + def test_throughput_calculation(self): + # Record 2 requests at t=0 and t=1s + with patch("time.time", side_effect=[100.0, 101.0]): + self.tracker.record_request(latency_ms=10, n_items=10) # t=100 + self.tracker.record_request(latency_ms=10, 
n_items=10) # t=101 + + stats = self.tracker.throughput_stats() + # 20 items over 1 second = 20 items/sec + self.assertAlmostEqual(stats["items_per_sec"], 20.0) + # 2 requests over 1 second = 2 req/sec + self.assertAlmostEqual(stats["requests_per_sec"], 2.0) + + def test_compression_stats(self): + # 50% compression + self.tracker.record_request( + latency_ms=10, payload_bytes=1000, compressed_bytes=500 + ) + # 25% compression + self.tracker.record_request( + latency_ms=10, payload_bytes=1000, compressed_bytes=250 + ) + + stats = self.tracker.compression_stats() + # mean of 0.5 and 0.25 is 0.375 + self.assertAlmostEqual(stats["mean_compression_ratio"], 0.375) + + def test_metrics_dict_formatting(self): + self.tracker.record_request(latency_ms=50, n_items=5, success=True) + self.tracker.record_request(latency_ms=100, n_items=5, success=False) + + metrics = self.tracker.metrics_dict() + self.assertIn("api_perf/latency_p50_ms", metrics) + self.assertIn("api_perf/items_per_sec", metrics) + self.assertEqual(metrics["api_perf/failed_requests"], 1) + self.assertEqual(metrics["api_perf/error_rate"], 0.5) + + def test_slow_request_warning(self): + with self.assertLogs("atroposlib.utils.api_perf", level="WARNING") as cm: + self.tracker.slow_request_threshold_ms = 100 + self.tracker.record_request(latency_ms=150) + self.assertTrue(any("Slow API request" in msg for msg in cm.output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/atroposlib/tests/test_curriculum.py b/atroposlib/tests/test_curriculum.py new file mode 100644 index 000000000..f8b421c6d --- /dev/null +++ b/atroposlib/tests/test_curriculum.py @@ -0,0 +1,248 @@ +""" +Tests for CurriculumScheduler -- difficulty-based sampling for RL training. 
+ +Tests cover: +- Uniform passthrough (default behavior unchanged) +- Easy-first annealing +- Competence-based frontier sampling +- EMA difficulty updates +- Bin assignment with quantile boundaries +- Metrics and state persistence +- Edge cases +""" + +import math +import random + +import pytest + +from atroposlib.envs.curriculum import CurriculumScheduler, CurriculumStrategy + +# --------------------------------------------------------------------------- +# Strategy tests +# --------------------------------------------------------------------------- + + +class TestUniformStrategy: + def test_uniform_returns_valid_bins(self): + scheduler = CurriculumScheduler(strategy="uniform", n_bins=5) + bins = [scheduler.sample_bin(step, 1000) for step in range(100)] + assert all(0 <= b < 5 for b in bins) + + def test_uniform_covers_all_bins(self): + """Uniform should eventually sample from every bin.""" + random.seed(42) + scheduler = CurriculumScheduler(strategy="uniform", n_bins=5) + bins = set() + for _ in range(200): + bins.add(scheduler.sample_bin(0, 1000)) + assert bins == {0, 1, 2, 3, 4} + + +class TestEasyFirstStrategy: + def test_early_training_prefers_easy(self): + """At step 0, easy_first should strongly prefer low bins (easy).""" + random.seed(42) + scheduler = CurriculumScheduler( + strategy="easy_first", n_bins=5, temperature=0.5 + ) + bins = [scheduler.sample_bin(0, 1000) for _ in range(200)] + easy_count = sum(1 for b in bins if b <= 1) + hard_count = sum(1 for b in bins if b >= 3) + # Early training should have more easy than hard + assert easy_count > hard_count + + def test_late_training_approaches_uniform(self): + """Near the end (step~total), easy_first should be roughly uniform.""" + random.seed(42) + scheduler = CurriculumScheduler( + strategy="easy_first", n_bins=5, temperature=1.0 + ) + probs = scheduler._easy_first_probs(progress=1.0) + # At progress=1.0, all probs should be near 1/n_bins + for p in probs: + assert abs(p - 0.2) < 0.05 + + +class 
TestCompetenceBasedStrategy: + def test_competence_frontier_moves(self): + """The frontier should shift from easy to hard as training progresses.""" + scheduler = CurriculumScheduler( + strategy="competence_based", n_bins=5, temperature=0.5 + ) + + # Early training: frontier at easy bins + random.seed(42) + early_bins = [scheduler.sample_bin(0, 1000) for _ in range(200)] + early_mean = sum(early_bins) / len(early_bins) + + # Late training: frontier at hard bins + late_bins = [scheduler.sample_bin(900, 1000) for _ in range(200)] + late_mean = sum(late_bins) / len(late_bins) + + # Late mean should be higher (harder bins) + assert late_mean > early_mean + + def test_mid_training_prefers_middle(self): + """At 50% progress, competence_based should prefer middle bins.""" + random.seed(42) + scheduler = CurriculumScheduler( + strategy="competence_based", n_bins=5, temperature=0.5 + ) + bins = [scheduler.sample_bin(500, 1000) for _ in range(300)] + mid_count = sum(1 for b in bins if 1 <= b <= 3) + edge_count = sum(1 for b in bins if b == 0 or b == 4) + assert mid_count > edge_count + + +# --------------------------------------------------------------------------- +# EMA difficulty tracking tests +# --------------------------------------------------------------------------- + + +class TestDifficultyTracking: + def test_ema_update(self): + scheduler = CurriculumScheduler(strategy="uniform", ema_alpha=0.5) + scheduler.update("item_1", 1.0) + assert math.isclose(scheduler.get_item_difficulty("item_1"), 1.0) + + scheduler.update("item_1", 0.0) + # EMA: 0.5 * 0.0 + 0.5 * 1.0 = 0.5 + assert math.isclose(scheduler.get_item_difficulty("item_1"), 0.5) + + def test_batch_update(self): + scheduler = CurriculumScheduler(strategy="uniform") + scheduler.update_batch("item_1", [0.8, 0.6, 1.0]) + # Should use average: 0.8 + diff = scheduler.get_item_difficulty("item_1") + assert diff is not None + assert math.isclose(diff, 0.8) + + def test_untracked_item_returns_none(self): + scheduler = 
CurriculumScheduler(strategy="uniform") + assert scheduler.get_item_difficulty("nonexistent") is None + + def test_multiple_items_tracked(self): + scheduler = CurriculumScheduler(strategy="uniform") + scheduler.update("easy", 0.9) + scheduler.update("hard", 0.1) + scheduler.update("medium", 0.5) + + assert scheduler.n_items_tracked == 3 + assert scheduler.get_item_difficulty("easy") > scheduler.get_item_difficulty( + "hard" + ) + + +# --------------------------------------------------------------------------- +# Bin assignment tests +# --------------------------------------------------------------------------- + + +class TestBinAssignment: + def test_easy_item_gets_low_bin(self): + scheduler = CurriculumScheduler(strategy="uniform", n_bins=5) + # Create items spanning the difficulty range + for i in range(100): + scheduler.update(f"item_{i}", i / 100.0) + + # High score = easy = low bin + easy_bin = scheduler.get_item_bin("item_95") + hard_bin = scheduler.get_item_bin("item_5") + assert easy_bin < hard_bin + + def test_untracked_gets_middle_bin(self): + scheduler = CurriculumScheduler(strategy="uniform", n_bins=5) + assert scheduler.get_item_bin("unknown") == 2 # n_bins // 2 + + def test_single_bin(self): + scheduler = CurriculumScheduler(strategy="uniform", n_bins=1) + scheduler.update("item", 0.5) + assert scheduler.get_item_bin("item") == 0 + + +# --------------------------------------------------------------------------- +# Metrics and state tests +# --------------------------------------------------------------------------- + + +class TestMetrics: + def test_metrics_dict_empty(self): + scheduler = CurriculumScheduler(strategy="uniform") + metrics = scheduler.metrics_dict() + assert "curriculum/items_tracked" in metrics + assert metrics["curriculum/items_tracked"] == 0 + + def test_metrics_dict_populated(self): + scheduler = CurriculumScheduler(strategy="uniform", n_bins=3) + for i in range(60): # Enough to trigger rebinning + scheduler.update(f"item_{i}", i / 
60.0) + + metrics = scheduler.metrics_dict() + assert metrics["curriculum/items_tracked"] == 60 + assert "curriculum/mean_difficulty" in metrics + assert "curriculum/min_difficulty" in metrics + assert "curriculum/max_difficulty" in metrics + assert "curriculum/total_updates" in metrics + + +class TestStatePersistence: + def test_save_load_roundtrip(self): + scheduler = CurriculumScheduler( + strategy="competence_based", n_bins=3, temperature=0.8 + ) + for i in range(20): + scheduler.update(f"item_{i}", i / 20.0) + + state = scheduler.state_dict() + + scheduler2 = CurriculumScheduler(strategy="uniform") + scheduler2.load_state_dict(state) + + assert scheduler2.strategy == "competence_based" + assert scheduler2.n_bins == 3 + assert math.isclose(scheduler2.temperature, 0.8) + assert scheduler2.n_items_tracked == 20 + + # Difficulty scores should match + for i in range(20): + key = f"item_{i}" + d1 = scheduler.get_item_difficulty(key) + d2 = scheduler2.get_item_difficulty(key) + assert math.isclose(d1, d2) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_invalid_strategy_raises(self): + with pytest.raises(ValueError, match="Invalid curriculum strategy"): + CurriculumScheduler(strategy="invalid") + + def test_invalid_n_bins_raises(self): + with pytest.raises(ValueError, match="n_bins must be >= 1"): + CurriculumScheduler(n_bins=0) + + def test_temperature_floor(self): + scheduler = CurriculumScheduler(temperature=0.001) + assert scheduler.temperature >= 0.01 + + def test_ema_alpha_clamped(self): + scheduler = CurriculumScheduler(ema_alpha=2.0) + assert scheduler.ema_alpha <= 1.0 + + scheduler2 = CurriculumScheduler(ema_alpha=-1.0) + assert scheduler2.ema_alpha >= 0.0 + + def test_empty_batch_update(self): + scheduler = CurriculumScheduler(strategy="uniform") + scheduler.update_batch("item", []) + assert 
scheduler.n_items_tracked == 0 + + def test_strategy_enum_values(self): + assert CurriculumStrategy.UNIFORM.value == "uniform" + assert CurriculumStrategy.EASY_FIRST.value == "easy_first" + assert CurriculumStrategy.COMPETENCE_BASED.value == "competence_based" diff --git a/atroposlib/tests/test_reward_ensemble.py b/atroposlib/tests/test_reward_ensemble.py new file mode 100644 index 000000000..f7460690a --- /dev/null +++ b/atroposlib/tests/test_reward_ensemble.py @@ -0,0 +1,310 @@ +""" +Tests for EnsembleReward -- reward aggregation with inter-rater reliability. + +Tests cover: +- All aggregation strategies (mean, median, min, majority_vote) +- Krippendorff's alpha computation (perfect/no agreement) +- Disagreement tracking +- Registry integration +- Edge cases (empty completions, single reward function) +""" + +import math +from typing import Any, List + +import numpy as np +import pytest + +from atroposlib.envs.reward_fns.ensemble_reward import ( + EnsembleReward, + _krippendorff_alpha, +) +from atroposlib.envs.reward_fns.registry import RewardRegistry +from atroposlib.envs.reward_fns.reward_function import RewardFunction + +# --------------------------------------------------------------------------- +# Test fixtures -- simple reward functions for composing ensembles +# --------------------------------------------------------------------------- + + +class ConstantReward(RewardFunction): + """Returns a fixed score for every completion.""" + + def __init__(self, value: float = 1.0, **kwargs): + super().__init__(**kwargs) + self._value = value + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + return [self._value] * len(completions) + + +class LengthReward(RewardFunction): + """Scores by string length (for testing divergent reward signals).""" + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + return [float(len(self.get_content(c))) for c in completions] + + +class BinaryReward(RewardFunction): + """Returns 1.0 if 
completion contains 'good', else 0.0.""" + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + return [ + 1.0 if "good" in self.get_content(c).lower() else 0.0 for c in completions + ] + + +def _make_ensemble(strategy, reward_functions): + """Helper to construct an EnsembleReward without going through registry.""" + ensemble = EnsembleReward.__new__(EnsembleReward) + ensemble.weight = 1.0 + ensemble.strategy = strategy + ensemble.track_disagreement = True + ensemble.reward_functions = reward_functions + ensemble.wandb_logger = None + ensemble._name = None + ensemble.config = {} + ensemble.last_reliability_alpha = float("nan") + ensemble.last_disagreement_scores = None + ensemble._all_sub_rewards = None + return ensemble + + +@pytest.fixture +def test_registry(): + """Create a clean registry with test reward functions.""" + reg = RewardRegistry() + reg.register(name="constant")(ConstantReward) + reg.register(name="length")(LengthReward) + reg.register(name="binary")(BinaryReward) + return reg + + +@pytest.fixture +def completions(): + """Sample completions for testing.""" + return ["short", "a medium length string", "good answer here"] + + +# --------------------------------------------------------------------------- +# Aggregation strategy tests +# --------------------------------------------------------------------------- + + +class TestMeanAggregation: + def test_mean_of_identical_scores(self, completions): + ensemble = _make_ensemble( + "mean", + [ + ConstantReward(value=2.0), + ConstantReward(value=2.0), + ], + ) + scores = ensemble.compute(completions) + assert len(scores) == 3 + assert all(math.isclose(s, 2.0, rel_tol=1e-9) for s in scores) + + def test_mean_of_different_scores(self, completions): + ensemble = _make_ensemble( + "mean", + [ + ConstantReward(value=1.0), + ConstantReward(value=3.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 2.0, rel_tol=1e-9) for s in scores) + + +class 
TestMedianAggregation: + def test_median_rejects_outlier(self, completions): + """Median should be robust to a single outlier reward function.""" + ensemble = _make_ensemble( + "median", + [ + ConstantReward(value=1.0), + ConstantReward(value=1.0), + ConstantReward(value=100.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 1.0, rel_tol=1e-9) for s in scores) + + +class TestMinAggregation: + def test_min_is_conservative(self, completions): + ensemble = _make_ensemble( + "min", + [ + ConstantReward(value=0.5), + ConstantReward(value=0.8), + ConstantReward(value=1.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 0.5, rel_tol=1e-9) for s in scores) + + +class TestMajorityVoteAggregation: + def test_majority_positive(self, completions): + ensemble = _make_ensemble( + "majority_vote", + [ + ConstantReward(value=1.0), + ConstantReward(value=1.0), + ConstantReward(value=-1.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 1.0) for s in scores) + + def test_majority_negative(self, completions): + ensemble = _make_ensemble( + "majority_vote", + [ + ConstantReward(value=-1.0), + ConstantReward(value=-1.0), + ConstantReward(value=1.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 0.0) for s in scores) + + def test_tie_goes_positive(self, completions): + ensemble = _make_ensemble( + "majority_vote", + [ + ConstantReward(value=1.0), + ConstantReward(value=-1.0), + ], + ) + scores = ensemble.compute(completions) + assert all(math.isclose(s, 1.0) for s in scores) + + +# --------------------------------------------------------------------------- +# Inter-rater reliability tests +# --------------------------------------------------------------------------- + + +class TestKrippendorffAlpha: + def test_perfect_agreement(self): + ratings = np.array( + [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + ] + ) + alpha = 
_krippendorff_alpha(ratings) + assert math.isclose(alpha, 1.0, rel_tol=1e-9) + + def test_no_agreement(self): + ratings = np.array( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + ] + ) + alpha = _krippendorff_alpha(ratings) + assert alpha < 0.0 + + def test_random_agreement(self): + np.random.seed(42) + ratings = np.random.rand(5, 100) + alpha = _krippendorff_alpha(ratings) + assert abs(alpha) < 0.3 + + def test_insufficient_data(self): + alpha = _krippendorff_alpha(np.array([[1.0, 2.0, 3.0]])) + assert math.isnan(alpha) + + alpha = _krippendorff_alpha(np.array([[1.0], [2.0]])) + assert math.isnan(alpha) + + +class TestReliabilityMetrics: + def test_reliability_computed_after_scoring(self, completions): + ensemble = _make_ensemble( + "mean", + [ + ConstantReward(value=1.0), + ConstantReward(value=1.0), + ], + ) + ensemble.compute(completions) + metrics = ensemble.reliability_metrics() + assert "alpha" in metrics + assert math.isclose(metrics["alpha"], 1.0, rel_tol=1e-9) + + def test_disagreement_tracked(self, completions): + ensemble = _make_ensemble( + "mean", + [ + ConstantReward(value=0.0), + ConstantReward(value=10.0), + ], + ) + ensemble.compute(completions) + assert ensemble.last_disagreement_scores is not None + assert len(ensemble.last_disagreement_scores) == len(completions) + # Variance of [0.0, 10.0] = 25.0 + assert all( + math.isclose(d, 25.0, rel_tol=1e-9) + for d in ensemble.last_disagreement_scores + ) + + +# --------------------------------------------------------------------------- +# Registry integration +# --------------------------------------------------------------------------- + + +class TestRegistryIntegration: + def test_create_via_registry(self, test_registry): + # EnsembleReward.__init__ resolves sub-rewards via the global registry, + # so we must register our test fixtures there too. 
+ from atroposlib.envs.reward_fns.registry import registry as global_registry + + global_registry.register(name="test_constant")(ConstantReward) + test_registry.register(name="ensemble")(EnsembleReward) + + ensemble = test_registry.create( + { + "type": "ensemble", + "rewards": ["test_constant", "test_constant"], + "strategy": "median", + } + ) + assert isinstance(ensemble, EnsembleReward) + assert ensemble.strategy == "median" + assert len(ensemble.reward_functions) == 2 + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_completions(self): + ensemble = _make_ensemble("mean", [ConstantReward(value=1.0)]) + scores = ensemble.compute([]) + assert scores == [] + + def test_invalid_strategy_raises(self): + with pytest.raises(ValueError, match="Invalid strategy"): + EnsembleReward(rewards=[], strategy="nonexistent") + + def test_name_format(self): + ensemble = _make_ensemble( + "median", + [ + ConstantReward(value=1.0), + LengthReward(), + ], + ) + name = ensemble.name + assert "ensemble_median" in name + assert "constantreward" in name + assert "lengthreward" in name diff --git a/atroposlib/tests/test_reward_normalization.py b/atroposlib/tests/test_reward_normalization.py new file mode 100644 index 000000000..abda7f5c3 --- /dev/null +++ b/atroposlib/tests/test_reward_normalization.py @@ -0,0 +1,250 @@ +""" +Tests for RewardNormalizer -- online reward normalization with Welford's algorithm. 
+ +Tests cover: +- Welford's accumulator numerical accuracy vs numpy +- Z-score normalization +- Min-max normalization +- Clipping behavior +- Warmup period +- State save/load roundtrip +- Edge cases (empty input, constant values, mode validation) +""" + +import math + +import numpy as np +import pytest + +from atroposlib.envs.reward_normalization import RewardNormalizer, WelfordAccumulator + +# --------------------------------------------------------------------------- +# WelfordAccumulator tests +# --------------------------------------------------------------------------- + + +class TestWelfordAccumulator: + def test_single_value(self): + acc = WelfordAccumulator() + acc.update(5.0) + assert acc.count == 1 + assert math.isclose(acc.mean, 5.0) + assert math.isclose(acc.variance, 0.0) + + def test_matches_numpy(self): + """Welford's running stats should match numpy's batch computation.""" + np.random.seed(42) + values = np.random.randn(1000).tolist() + + acc = WelfordAccumulator() + acc.update_batch(values) + + expected_mean = np.mean(values) + expected_var = np.var(values) # population variance + + assert math.isclose(acc.mean, expected_mean, rel_tol=1e-9) + assert math.isclose(acc.variance, expected_var, rel_tol=1e-6) + + def test_min_max_tracking(self): + acc = WelfordAccumulator() + acc.update_batch([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]) + assert math.isclose(acc.min_val, 1.0) + assert math.isclose(acc.max_val, 9.0) + + def test_state_roundtrip(self): + acc = WelfordAccumulator() + acc.update_batch([1.0, 2.0, 3.0, 4.0, 5.0]) + state = acc.state_dict() + + acc2 = WelfordAccumulator() + acc2.load_state_dict(state) + + assert acc2.count == acc.count + assert math.isclose(acc2.mean, acc.mean) + assert math.isclose(acc2.variance, acc.variance) + assert math.isclose(acc2.min_val, acc.min_val) + assert math.isclose(acc2.max_val, acc.max_val) + + def test_empty_accumulator(self): + acc = WelfordAccumulator() + assert acc.count == 0 + assert math.isclose(acc.mean, 
0.0) + assert math.isclose(acc.variance, 0.0) + assert math.isclose(acc.std, 0.0) + + +# --------------------------------------------------------------------------- +# RewardNormalizer z-score tests +# --------------------------------------------------------------------------- + + +class TestZScoreNormalization: + def test_zscore_centers_around_zero(self): + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=0) + # Feed enough data to establish stats + normalizer.normalize([1.0, 2.0, 3.0, 4.0, 5.0] * 10) + # Now normalize a new batch + result = normalizer.normalize([3.0]) # mean should be ~3.0 + assert abs(result[0]) < 0.1 # Should be near 0 + + def test_zscore_output_scale(self): + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=0) + # Standard normal-ish data + np.random.seed(42) + data = np.random.randn(500).tolist() + normalizer.normalize(data) + + # Normalize the same data again + result = normalizer.normalize(data) + # After normalization, std should be approximately 1.0 + result_std = np.std(result) + assert 0.8 < result_std < 1.2 + + def test_zscore_constant_values(self): + """Constant values should normalize to 0.""" + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=0) + result = normalizer.normalize([5.0, 5.0, 5.0, 5.0, 5.0]) + assert all(math.isclose(s, 0.0) for s in result) + + +# --------------------------------------------------------------------------- +# RewardNormalizer min-max tests +# --------------------------------------------------------------------------- + + +class TestMinMaxNormalization: + def test_minmax_scales_to_unit_range(self): + normalizer = RewardNormalizer(mode="minmax", clip=None, warmup=0) + normalizer.normalize([0.0, 10.0]) # Establish min=0, max=10 + result = normalizer.normalize([0.0, 5.0, 10.0]) + assert math.isclose(result[0], 0.0, abs_tol=1e-6) + assert math.isclose(result[1], 0.5, abs_tol=1e-3) + assert math.isclose(result[2], 1.0, abs_tol=1e-6) + + def 
test_minmax_constant_returns_half(self): + normalizer = RewardNormalizer(mode="minmax", clip=None, warmup=0) + result = normalizer.normalize([3.0, 3.0, 3.0]) + assert all(math.isclose(s, 0.5) for s in result) + + +# --------------------------------------------------------------------------- +# Clipping tests +# --------------------------------------------------------------------------- + + +class TestClipping: + def test_clip_bounds(self): + normalizer = RewardNormalizer(mode="zscore", clip=2.0, warmup=0) + # Feed data with a big outlier + normalizer.normalize([0.0] * 100) + result = normalizer.normalize([1000.0]) + assert result[0] <= 2.0 + + def test_no_clip_when_disabled(self): + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=0) + normalizer.normalize([0.0] * 100) + result = normalizer.normalize([1000.0]) + assert result[0] > 2.0 # Should NOT be clipped + + def test_negative_clip_disabled(self): + normalizer = RewardNormalizer(mode="zscore", clip=-1.0, warmup=0) + assert normalizer.clip is None + + +# --------------------------------------------------------------------------- +# Warmup tests +# --------------------------------------------------------------------------- + + +class TestWarmup: + def test_warmup_returns_raw(self): + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=10) + # During warmup, should return raw scores + result = normalizer.normalize([5.0, 10.0]) + assert math.isclose(result[0], 5.0) + assert math.isclose(result[1], 10.0) + + def test_warmup_transition(self): + normalizer = RewardNormalizer(mode="zscore", clip=None, warmup=5) + # Feed 3 values (under warmup) + r1 = normalizer.normalize([1.0, 2.0, 3.0]) + assert not normalizer.is_warmed_up + # Raw values during warmup + assert math.isclose(r1[0], 1.0) + + # Feed 3 more (now at 6, above warmup) + r2 = normalizer.normalize([4.0, 5.0, 6.0]) + assert normalizer.is_warmed_up + # Should be normalized now (not raw) + assert not math.isclose(r2[0], 4.0) + + +# 
--------------------------------------------------------------------------- +# State persistence tests +# --------------------------------------------------------------------------- + + +class TestStatePersistence: + def test_save_load_roundtrip(self): + normalizer = RewardNormalizer(mode="zscore", clip=3.0, warmup=5) + normalizer.normalize([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) + + state = normalizer.state_dict() + + normalizer2 = RewardNormalizer() + normalizer2.load_state_dict(state) + + assert normalizer2.mode == "zscore" + assert normalizer2.clip == 3.0 + assert normalizer2.warmup == 5 + assert normalizer2.count == normalizer.count + assert math.isclose(normalizer2.mean, normalizer.mean) + assert math.isclose(normalizer2.std, normalizer.std) + + def test_loaded_normalizer_continues(self): + """A loaded normalizer should produce same results as the original.""" + normalizer = RewardNormalizer(mode="zscore", clip=5.0, warmup=0) + normalizer.normalize([1.0, 2.0, 3.0, 4.0, 5.0] * 10) + state = normalizer.state_dict() + + normalizer2 = RewardNormalizer() + normalizer2.load_state_dict(state) + + test_data = [2.5, 3.5, 4.5] + r1 = normalizer.normalize(test_data) + r2 = normalizer2.normalize(test_data) + + # Results won't be identical because normalize also updates stats, + # but they should be very close for the first call after loading + for a, b in zip(r1, r2): + assert math.isclose(a, b, rel_tol=1e-3) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_input(self): + normalizer = RewardNormalizer(mode="zscore") + assert normalizer.normalize([]) == [] + + def test_none_mode_passthrough(self): + normalizer = RewardNormalizer(mode="none") + scores = [1.0, 2.0, 3.0] + assert normalizer.normalize(scores) == scores + + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="Invalid 
# NOTE(review): reconstructed from a whitespace-mangled diff paste. This span
# carries the tail of test_reward_normalization.py, the whole of
# atroposlib/tests/verify_e2e.py and verify_hermes_compat.py, and the header
# of atroposlib/utils/api_perf.py.

# --- atroposlib/tests/test_reward_normalization.py (continuation) -----------
# The method below continues a TestEdgeCases class whose earlier methods lie
# in the preceding span; kept under a placeholder class pending confirmation.


class _TestEdgeCasesTail:
    def test_metrics_dict_keys(self):
        """metrics_dict() exposes the expected wandb-style keys."""
        normalizer = RewardNormalizer(mode="zscore", warmup=0)
        normalizer.normalize([1.0, 2.0, 3.0])
        metrics = normalizer.metrics_dict()
        assert "reward_norm/count" in metrics
        assert "reward_norm/mean" in metrics
        assert "reward_norm/std" in metrics
        assert "reward_norm/min" in metrics
        assert "reward_norm/max" in metrics


# --- atroposlib/tests/verify_e2e.py ------------------------------------------

import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup, ScoredDataItem
from atroposlib.envs.server_handling.server_manager import (
    APIServerConfig,
    ServerBaseline,
)
from atroposlib.type_definitions import Item

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("verify_e2e")


class MockEnvConfig(BaseEnvConfig):
    """Configuration for E2E verification.

    Inherits every new feature field from BaseEnvConfig; only the values
    needed by the mock run are overridden here.
    """

    tokenizer_name: str = "gpt2"
    group_size: int = 4
    # NOTE(review): BaseEnvConfig names its binning field `curriculum_bins`;
    # this extra field does not feed the curriculum scheduler — confirm intent
    # and consolidate.
    num_difficulty_bins: int = 5


class MockEnv(BaseEnv):
    """A minimal environment to verify BaseEnv integration features."""

    async def setup(self):
        self.items = [f"item_{i}" for i in range(20)]
        self.iter = 0

    async def get_next_item(self) -> Item:
        item = self.items[self.iter % len(self.items)]
        self.iter += 1
        return item

    def format_prompt(self, item: Item) -> str:
        return f"Prompt for {item}"

    async def collect_trajectory(self, item: Item) -> Tuple[ScoredDataItem, List[Item]]:
        # Simulate a rollout whose reward scales with the item index so that
        # curriculum scheduling and normalization both see varied difficulty.
        idx = int(item.split("_")[1])
        base_reward = float(idx) / 20.0

        # Multiple scores to trigger consensus.
        scores = [base_reward, base_reward + 0.1, base_reward - 0.1]

        # Add some noise to test stability.
        scores = [s + np.random.normal(0, 0.01) for s in scores]

        return {
            "tokens": [1, 2, 3],
            "masks": [0, 1, 1],
            "scores": scores,  # List of scores triggers Ensemble
        }, []

    async def evaluate(self, *args, **kwargs):
        pass


async def main():
    """Run the end-to-end readiness check across the opt-in RL features."""
    logger.info("Starting E2E Readiness Check...")
    # Seed the RNG so repeated readiness checks are comparable/deterministic.
    np.random.seed(0)

    # 1. Setup Config with ALL features enabled
    config = MockEnvConfig(
        reward_mode="consensus",  # Ensemble
        reward_normalization="zscore",  # Normalization
        curriculum_strategy="easy_first",  # Curriculum
        track_api_perf=True,  # Perf Tracker
        use_wandb=False,  # No real wandb for test
        tokenizer_name="gpt2",
        group_size=4,
        # BUGFIX: the config field is `reward_normalization_warmup`; the old
        # `warmup_steps=2` kwarg does not exist on BaseEnvConfig and was
        # silently dropped, so normalization never transitioned early as
        # intended by this check.
        reward_normalization_warmup=2,
    )

    # Mock server config (BaseEnv needs it but we won't use real server)
    server_configs = [
        APIServerConfig(
            model_name="mock", base_url="http://localhost:8000", api_key="test"
        )
    ]

    # 2. Initialize Env
    env = MockEnv(config, server_configs, testing=True)
    await env.setup()

    logger.info("Environment initialized with all RL features.")

    # 3. Simulate 5 steps (rollout groups)
    for step in range(1, 6):
        logger.info(f"--- Step {step} ---")

        # Get item (this should go through the curriculum scheduler)
        item = await env.get_next_item()

        # Collect (this triggers ensemble and normalization update)
        results, next_items = await env.collect_trajectories(item)

        # Manually trigger wandb_log (it updates stats and formats metrics)
        metrics = {}
        await env.wandb_log(metrics)

        # 4. Verify presence of expected keys
        expected_prefixes = ["reward_norm/", "curriculum/", "api_perf/"]
        found_keys = [
            k for k in metrics.keys() if any(k.startswith(p) for p in expected_prefixes)
        ]

        logger.info(f"Metrics keys found: {found_keys}")

        if "reward_norm/mean" in metrics:
            logger.info(f"Normalization Mean: {metrics['reward_norm/mean']:.4f}")
        if "curriculum/target_bin" in metrics:
            logger.info(f"Curriculum Target Bin: {metrics['curriculum/target_bin']}")
        if "api_perf/items_per_sec" in metrics:
            logger.info(
                f"API Throughput: {metrics['api_perf/items_per_sec']:.2f} items/s"
            )

    logger.info("E2E Readiness Check Completed Successfully!")


if __name__ == "__main__":
    asyncio.run(main())


# --- atroposlib/tests/verify_hermes_compat.py --------------------------------

import os
import sys
from pathlib import Path

# 1. Add hermes-agent and atropos to sys.path.
# BUGFIX: the roots were hard-coded to one developer's home directory; they
# are now overridable via environment variables, with the old locations kept
# as fallbacks so existing workflows keep working.
repo_root = Path(
    os.environ.get("HERMES_AGENT_ROOT", "/home/ruffy-369/NousResearch/hermes-agent")
)
atropos_root = Path(
    os.environ.get("ATROPOS_ROOT", "/home/ruffy-369/NousResearch/atropos")
)

if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
if str(atropos_root) not in sys.path:
    sys.path.insert(0, str(atropos_root))

try:
    from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig

    from atroposlib.envs.server_handling.server_manager import APIServerConfig

    print("✅ Import successful")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    sys.exit(1)


def test_init():
    """Instantiate HermesAgentBaseEnv and exercise wandb_log compatibility."""
    print("Testing HermesAgentBaseEnv initialization...")
    try:
        # Create a config
        config = HermesAgentEnvConfig(tokenizer_name="gpt2", group_size=4)

        # Verify inheritance of new Atropos fields
        print(
            f"Checking inherited fields: reward_mode={config.reward_mode}, track_api_perf={config.track_api_perf}"
        )

        # Mock server configs
        server_configs = [
            APIServerConfig(
                model_name="mock", base_url="http://localhost:8000", api_key="test"
            )
        ]

        # Initialize (with testing=True to use ServerHarness)
        env = HermesAgentBaseEnv(config, server_configs, testing=True)
        print("✅ Initialization successful")

        # Verify wandb_log signature compatibility
        print("Checking wandb_log signature compat...")
        import asyncio

        asyncio.run(env.wandb_log({}))
        print("✅ wandb_log call successful")

    except Exception as e:
        print(f"❌ Initialization failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    test_init()
    print("🚀 hermes-agent Compatibility Check PASSED!")


# --- atroposlib/utils/api_perf.py (module header only; the module body lies
#     beyond this span) --------------------------------------------------------
# Original header: `from __future__ import annotations` followed by the module
# docstring ("API performance tracker for trainer-inference communication
# optimization. ..."). NOTE(review): with the future-import placed first, the
# string literal after it does NOT bind to __doc__ — in the real file the
# docstring must come first, then the future-import (PEP 236).
+ +Provides lightweight latency and throughput monitoring for the scored_data +API round-trip, enabling bottleneck identification in the trainer-inference +communication pipeline. + +Features: +- Rolling window latency tracking (configurable window size) +- Throughput computation (items/sec, requests/sec) +- Percentile latency statistics (p50, p95, p99) +- Compression ratio tracking +- WandB-compatible metrics output + +Usage: + tracker = APIPerformanceTracker(window_size=100) + + # Around API call + with tracker.track_request(n_items=group_size, payload_bytes=len(data)): + await send_scored_data(...) + + # Log to wandb + wandb.log(tracker.metrics_dict()) +""" + +import logging +import time +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Dict + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class RequestRecord: + """Record of a single API request.""" + + latency_ms: float + n_items: int + payload_bytes: int + compressed_bytes: int + timestamp: float + success: bool = True + + +class APIPerformanceTracker: + """ + Lightweight performance tracker for trainer-inference API communication. + + Maintains a rolling window of request records for computing latency + and throughput statistics without unbounded memory growth. + + Args: + window_size: Number of recent requests to keep for stats. Default: 200. + slow_request_threshold_ms: Latency above this triggers a warning. Default: 5000. 
+ """ + + def __init__( + self, + window_size: int = 200, + slow_request_threshold_ms: float = 5000.0, + ): + self.window_size = max(1, window_size) + self.slow_request_threshold_ms = slow_request_threshold_ms + self._records: deque = deque(maxlen=self.window_size) + self._total_requests: int = 0 + self._total_items: int = 0 + self._total_bytes_sent: int = 0 + self._total_compressed_bytes: int = 0 + self._failed_requests: int = 0 + + @contextmanager + def track_request( + self, + n_items: int = 1, + payload_bytes: int = 0, + compressed_bytes: int = 0, + ): + """ + Context manager to track a single API request. + + Args: + n_items: Number of items (completions) in this request. + payload_bytes: Size of the uncompressed payload. + compressed_bytes: Size of the compressed payload (0 if no compression). + + Yields: + None. Timing is handled automatically. + """ + start = time.monotonic() + success = True + try: + yield + except Exception: + success = False + self._failed_requests += 1 + raise + finally: + elapsed_ms = (time.monotonic() - start) * 1000.0 + + record = RequestRecord( + latency_ms=elapsed_ms, + n_items=n_items, + payload_bytes=payload_bytes, + compressed_bytes=( + compressed_bytes if compressed_bytes > 0 else payload_bytes + ), + timestamp=time.time(), + success=success, + ) + self._records.append(record) + self._total_requests += 1 + self._total_items += n_items + self._total_bytes_sent += payload_bytes + self._total_compressed_bytes += ( + compressed_bytes if compressed_bytes > 0 else payload_bytes + ) + + if elapsed_ms > self.slow_request_threshold_ms: + logger.warning( + "Slow API request: %.1fms (threshold: %.1fms, items: %d)", + elapsed_ms, + self.slow_request_threshold_ms, + n_items, + ) + + def record_request( + self, + latency_ms: float, + n_items: int = 1, + payload_bytes: int = 0, + compressed_bytes: int = 0, + success: bool = True, + ): + """ + Manually record a request (for cases where context manager isn't suitable). 
+ + Args: + latency_ms: Request latency in milliseconds. + n_items: Number of items in the request. + payload_bytes: Uncompressed payload size. + compressed_bytes: Compressed payload size. + success: Whether the request succeeded. + """ + record = RequestRecord( + latency_ms=latency_ms, + n_items=n_items, + payload_bytes=payload_bytes, + compressed_bytes=( + compressed_bytes if compressed_bytes > 0 else payload_bytes + ), + timestamp=time.time(), + success=success, + ) + self._records.append(record) + self._total_requests += 1 + self._total_items += n_items + self._total_bytes_sent += payload_bytes + self._total_compressed_bytes += ( + compressed_bytes if compressed_bytes > 0 else payload_bytes + ) + if not success: + self._failed_requests += 1 + + if latency_ms > self.slow_request_threshold_ms: + logger.warning( + "Slow API request: %.1fms (threshold: %.1fms, items: %d)", + latency_ms, + self.slow_request_threshold_ms, + n_items, + ) + + @property + def n_records(self) -> int: + """Number of records in the rolling window.""" + return len(self._records) + + def latency_stats(self) -> Dict[str, float]: + """ + Compute latency statistics from the rolling window. + + Returns: + Dictionary with p50, p95, p99, mean, min, max latencies in ms. + """ + if not self._records: + return { + "p50_ms": 0.0, + "p95_ms": 0.0, + "p99_ms": 0.0, + "mean_ms": 0.0, + "min_ms": 0.0, + "max_ms": 0.0, + } + + latencies = np.array([r.latency_ms for r in self._records]) + return { + "p50_ms": float(np.percentile(latencies, 50)), + "p95_ms": float(np.percentile(latencies, 95)), + "p99_ms": float(np.percentile(latencies, 99)), + "mean_ms": float(np.mean(latencies)), + "min_ms": float(np.min(latencies)), + "max_ms": float(np.max(latencies)), + } + + def throughput_stats(self) -> Dict[str, float]: + """ + Compute throughput statistics from the rolling window. + + Returns: + Dictionary with items/sec and requests/sec over the window. 
+ """ + if len(self._records) < 2: + return { + "items_per_sec": 0.0, + "requests_per_sec": 0.0, + } + + records = list(self._records) + time_span = records[-1].timestamp - records[0].timestamp + if time_span <= 0: + return { + "items_per_sec": 0.0, + "requests_per_sec": 0.0, + } + + total_items = sum(r.n_items for r in records) + return { + "items_per_sec": total_items / time_span, + "requests_per_sec": len(records) / time_span, + } + + def compression_stats(self) -> Dict[str, float]: + """ + Compute compression statistics from the rolling window. + + Returns: + Dictionary with mean compression ratio and total bytes sent. + """ + if not self._records: + return { + "mean_compression_ratio": 1.0, + "mean_payload_bytes": 0.0, + } + + ratios = [] + payloads = [] + for r in self._records: + if r.payload_bytes > 0: + ratios.append(r.compressed_bytes / r.payload_bytes) + payloads.append(float(r.payload_bytes)) + + return { + "mean_compression_ratio": float(np.mean(ratios)) if ratios else 1.0, + "mean_payload_bytes": float(np.mean(payloads)), + } + + def metrics_dict(self) -> Dict[str, float]: + """ + Return all performance metrics for WandB logging. + + Returns: + Dictionary with keys prefixed by 'api_perf/' for clean namespacing. 
+ """ + metrics = {} + + latency = self.latency_stats() + for key, val in latency.items(): + metrics[f"api_perf/latency_{key}"] = val + + throughput = self.throughput_stats() + for key, val in throughput.items(): + metrics[f"api_perf/{key}"] = val + + compression = self.compression_stats() + for key, val in compression.items(): + metrics[f"api_perf/{key}"] = val + + metrics["api_perf/total_requests"] = float(self._total_requests) + metrics["api_perf/total_items"] = float(self._total_items) + metrics["api_perf/failed_requests"] = float(self._failed_requests) + metrics["api_perf/error_rate"] = self._failed_requests / max( + 1, self._total_requests + ) + + return metrics + + def reset(self): + """Clear all records and counters.""" + self._records.clear() + self._total_requests = 0 + self._total_items = 0 + self._total_bytes_sent = 0 + self._total_compressed_bytes = 0 + self._failed_requests = 0 diff --git a/pyproject.toml b/pyproject.toml index 6f23666c3..cdc47a2c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "jsonlines", "pydantic-cli", "hf_transfer", + "antlr4-python3-runtime==4.9.3", ] [project.scripts]