Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions atroposlib/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
# --- Curriculum-learning configuration (opt-in) ---
# These three fields are read in BaseEnv.__init__: when curriculum_strategy is
# anything other than "uniform", a CurriculumScheduler is constructed from them;
# with "uniform" no scheduler is created and self.curriculum stays None.
# NOTE(review): valid strategy names are only enforced by the description text,
# not by validation — a typo here presumably fails later inside
# CurriculumScheduler; confirm whether a Literal/enum constraint is wanted.
curriculum_strategy: str = Field(
    default="uniform",
    description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
    "'easy_first' = oversample easy items early then anneal, "
    "'competence_based' = sample at competence frontier. "
    "See Platanios et al. 2019 for competence-based curriculum.",
)
# Number of difficulty buckets the scheduler partitions the dataset into.
# ge=1 guards against a zero/negative bin count; 1 bin degenerates to uniform.
curriculum_bins: int = Field(
    default=5,
    ge=1,
    description="Number of difficulty bins for curriculum scheduling.",
)
# Softmax-style temperature over the difficulty bins. gt=0 (strict) because a
# temperature of exactly 0 would make the sampling distribution degenerate.
curriculum_temperature: float = Field(
    default=1.0,
    gt=0,
    description="Temperature for curriculum bin sampling. Higher = more uniform, "
    "lower = more concentrated on target difficulty.",
)


class BaseEnv(ABC):
Expand Down Expand Up @@ -262,6 +280,17 @@ def __init__(
self.max_token_len = -1
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
self.completion_lengths = []
# Initialize curriculum scheduler (opt-in via config)
if config.curriculum_strategy != "uniform":
from atroposlib.envs.curriculum import CurriculumScheduler

self.curriculum = CurriculumScheduler(
strategy=config.curriculum_strategy,
n_bins=config.curriculum_bins,
temperature=config.curriculum_temperature,
)
else:
self.curriculum = None
self.max_num_workers = config.max_num_workers
if self.max_num_workers == -1:
self.max_num_workers = config.max_num_workers_per_node * len(
Expand Down Expand Up @@ -674,6 +703,9 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None):
wandb_metrics["train/completion_lengths_p95"] = (
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
).mean()
# Log curriculum metrics if active
if self.curriculum is not None:
wandb_metrics.update(self.curriculum.metrics_dict())
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = self.perf_stats(wandb_metrics)
self.rollouts_for_wandb = []
Expand Down
Loading