Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions environments/community/arithmetic_chain/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Arithmetic Chain

Self-contained RL environment: procedurally generated multi-step integer problems (add / subtract / multiply from a starting value). The model must answer with `\boxed{integer}`; rewards use the same `math_verify` path as GSM8K.

**No Hugging Face dataset** — training items are sampled on the fly.

## Run (serve)

From the repo root, with Atropos API and an OpenAI-compatible inference server configured in `config_init` or via CLI overrides:

```bash
python environments/community/arithmetic_chain/arithmetic_chain_server.py serve --slurm false
```

## Process (debug rollouts)

```bash
python environments/community/arithmetic_chain/arithmetic_chain_server.py process \
--env.data_path_to_save_groups rollouts.jsonl \
--slurm false
```

Uses `ManagedServer` for token/logprob tracking (compatible with trainers that expect Atropos’ standard scored groups).
324 changes: 324 additions & 0 deletions environments/community/arithmetic_chain/arithmetic_chain_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
"""
Procedural multi-step arithmetic chains: start from an integer, apply add/sub/mul steps,
then answer the final value in \\boxed{}. Self-contained (no dataset download).
"""

import random
import time
from typing import List, Optional, Tuple, TypedDict, Union

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
ServerBaseline,
)
from atroposlib.type_definitions import Item

system_prompt = (
"You solve short arithmetic word problems. Think step by step if helpful, "
"then give the final integer inside \\boxed{} with no extra text after it.\n\n"
)


class ArithmeticChainRow(TypedDict):
question: str
answer: str


def sample_chain(
rng: random.Random, min_steps: int = 2, max_steps: int = 4
) -> ArithmeticChainRow:
value = rng.randint(2, 24)
parts = [f"You start with {value}."]
num_steps = rng.randint(min_steps, max_steps)
for _ in range(num_steps):
choices = ["add", "mul"]
if value > 2:
choices.append("sub")
op = rng.choice(choices)
if op == "add":
n = rng.randint(1, 18)
value = value + n
parts.append(f"Add {n}.")
elif op == "sub":
n = rng.randint(1, min(17, value - 1))
value = value - n
parts.append(f"Subtract {n}.")
else:
n = rng.randint(2, 9)
value = value * n
parts.append(f"Multiply by {n}.")
if abs(value) > 900:
break
parts.append("What is the resulting integer? Answer with \\boxed{your_answer}.")
question = " ".join(parts)
return {"question": question, "answer": str(int(value))}


class ArithmeticChainEnv(BaseEnv):
name = "arithmetic_chain"

def __init__(
self,
config: BaseEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer: list[float] = []
self.eval_metrics: list[tuple[str, float]] = []
self.train_rng = random.Random(42)
self.eval_rng = random.Random(2025)

@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, ServerBaseline]:
env_config = BaseEnvConfig(
tokenizer_name="meta-llama/Llama-3.2-1B",
group_size=8,
use_wandb=False,
rollout_server_url="http://localhost:8000",
total_steps=500,
batch_size=16,
steps_per_eval=50,
max_token_length=512,
wandb_name="arithmetic_chain",
)
server_config = APIServerConfig(
model_name="meta-llama/Llama-3.2-1B",
base_url="http://localhost:8001/v1",
api_key="x",
num_requests_for_eval=128,
)
return env_config, server_config

async def wandb_log(self, wandb_metrics: Optional[dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
if self.percent_correct_buffer:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
self.percent_correct_buffer = []
for key, val in self.eval_metrics:
wandb_metrics[key] = val
self.eval_metrics = []
await super().wandb_log(wandb_metrics)

async def setup(self):
self.train = [sample_chain(self.train_rng) for _ in range(4096)]
self.test = [sample_chain(self.eval_rng) for _ in range(64)]
self.iter = 0

def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)

async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
stop=(
[self.tokenizer.eos_token_id]
if self.tokenizer.eos_token_id is not None
else None
),
)
response_content = completion.choices[0].message.content

gold_parsed = parse(
"\\boxed{" + answer + "}",
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
answer_parsed = parse(
response_content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
score = 1 if verify(answer_parsed, gold_parsed) else 0
sample = {
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
{"role": "assistant", "content": response_content},
],
"question": question,
"gold_answer": answer,
"score": int(score),
"correct": bool(score),
"finish_reason": completion.choices[0].finish_reason,
}
return {"score": score, "sample": sample}

async def evaluate(self, *args, **kwargs):
start_time = time.time()
eval_tasks = [
self.rollout_and_score_eval(item["question"], item["answer"])
for item in self.test
]
results = await tqdm_asyncio.gather(*eval_tasks)
scores = [r["score"] for r in results]
samples = [r["sample"] for r in results]
percent_correct = sum(scores) / len(scores)
end_time = time.time()
self.eval_metrics.append(("eval/percent_correct", percent_correct))
await self.evaluate_log(
metrics={"eval/percent_correct": percent_correct},
samples=samples,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": 0.0,
"max_tokens": self.config.max_token_length,
},
)

async def collect_trajectories(
self, item: ArithmeticChainRow
) -> Tuple[ScoredDataGroup, list[Item]]:
user_message = {"role": "user", "content": item["question"]}
gold_answer = "\\boxed{" + item["answer"] + "}"
stop = (
[self.tokenizer.eos_token_id]
if self.tokenizer.eos_token_id is not None
else None
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completions = await managed.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=1.0,
stop=stop,
)
state = managed.get_state()
nodes = state["nodes"]

to_score = []
to_backlog = []
for i, chat_completion in enumerate(chat_completions.choices):
messages = (
{"role": "system", "content": system_prompt},
user_message,
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append(
{
"messages": messages,
"gold_answer": gold_answer,
"finish_reason": chat_completion.finish_reason,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog

async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["inference_logprobs"] = []
gold_parsed = parse(
rollout_group_data[0]["gold_answer"],
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) == 0:
return None
random.shuffle(rollout_group_data)
for item in rollout_group_data:
answer_parsed = parse(
item["messages"][-1]["content"],
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = verify(answer_parsed, gold_parsed)
tokens = item["tokens"]
masks = item["masks"]
logprobs = item["logprobs"]
if len([1 for m in masks if m != -100]) < 8:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["inference_logprobs"].append(logprobs)
scores["scores"].append(1.0 if reward else -1.0)
if len(scores["tokens"]) >= self.config.group_size:
break
if not scores["scores"]:
return None
for s in scores["scores"]:
self.percent_correct_buffer.append(max(s, 0))
if all(s == 1 for s in scores["scores"]):
token_lengths = [len(t) for t in scores["tokens"]]
if not token_lengths:
return None
max_allowed = self.config.max_token_length
threshold = max_allowed * 0.5
scores["scores"] = []
for length in token_lengths:
if length <= threshold:
scores["scores"].append(1.0)
else:
pct = (length - threshold) / (max_allowed - threshold)
pct = min(pct, 1.0)
scores["scores"].append(1.0 - pct)
if len(scores["scores"]) >= 2 and all(
scores["scores"][0] == s for s in scores["scores"]
):
return None
return scores

async def get_next_item(self) -> ArithmeticChainRow:
item = self.train[self.iter % len(self.train)]
self.iter += 1
return item


if __name__ == "__main__":
ArithmeticChainEnv.cli()