diff --git a/.secrets.baseline b/.secrets.baseline index f93f324a6..4bd3270a7 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -151,7 +151,7 @@ "filename": "atroposlib/envs/eval.py", "hashed_secret": "829c3804401b0727f70f73d4415e162400cbe57b", "is_verified": false, - "line_number": 218 + "line_number": 225 } ], "atroposlib/tests/test_reasoning_models.py": [ diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 3d3b6c207..11499cad6 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -791,6 +791,30 @@ async def evaluate_log( writer.write(sample) logger.info("Evaluation samples saved to %s", samples_filepath) + try: + from atroposlib.frontend.jsonl2html import generate_eval_html + + generate_eval_html(samples_filepath) + except Exception as e: + logger.warning("Failed to generate eval HTML viewer: %s", e) + + def log_eval_sample(self, sample): + """Stream-write a single eval sample to samples.jsonl. + + Lazy-initializes the writer on first call. Use this inside evaluate() + to write samples as they complete rather than batching at the end. + If using this, omit the samples= parameter from evaluate_log(). + """ + if self._eval_sample_writer is None: + if self.config.data_dir_to_save_evals is None: + return + os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True) + self._eval_samples_path = os.path.join( + self.config.data_dir_to_save_evals, "samples.jsonl" + ) + self._eval_sample_writer = jsonlines.open(self._eval_samples_path, "w") + self._eval_sample_writer.write(sample) + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10), @@ -1336,8 +1360,32 @@ async def _run_evaluate(self): """ Internal method to run evaluation with proper setup. 
""" + self._eval_sample_writer = None + self._eval_samples_path = None await self.setup() - await self.evaluate() + try: + await self.evaluate() + finally: + # Close streaming eval sample writer if it was used + if self._eval_sample_writer is not None: + self._eval_sample_writer.close() + if self._eval_samples_path: + try: + from atroposlib.frontend.jsonl2html import generate_eval_html + + generate_eval_html(self._eval_samples_path) + except Exception as e: + logger.warning("Failed to generate eval HTML: %s", e) + # Close JSONL trajectory writer if it was used + if self.jsonl_writer is not None: + self.jsonl_writer.close() + if self.config.data_path_to_save_groups: + try: + from atroposlib.frontend.jsonl2html import generate_html + + generate_html(self.config.data_path_to_save_groups) + except Exception as e: + logger.warning("Failed to generate trajectory HTML: %s", e) @classmethod def cli(cls): @@ -1928,6 +1976,10 @@ def run(self) -> None: env_config_dict_base = default_env_config_from_init.model_dump() # Apply specific overrides for evaluate mode that are generally useful env_config_dict_base["use_wandb"] = True + if env_config_dict_base.get("data_dir_to_save_evals") is None: + env_config_dict_base["data_dir_to_save_evals"] = ( + f"eval_results/{cls.name or 'eval'}" + ) env_config_dict = merge_dicts( env_config_dict_base, # `config_init` defaults with evaluate adjustments diff --git a/atroposlib/envs/eval.py b/atroposlib/envs/eval.py index 21090b03f..0760c128a 100644 --- a/atroposlib/envs/eval.py +++ b/atroposlib/envs/eval.py @@ -104,6 +104,13 @@ def evaluate_log( writer.write(sample) print(f"Evaluation samples saved to {samples_filepath}") + try: + from atroposlib.frontend.jsonl2html import generate_eval_html + + generate_eval_html(samples_filepath) + except Exception as e: + print(f"Warning: Failed to generate eval HTML viewer: {e}") + class EvalBase(ABC): """ """ diff --git a/atroposlib/frontend/jsonl2html.py b/atroposlib/frontend/jsonl2html.py index 
264dbb64a..57e8c9741 100644 --- a/atroposlib/frontend/jsonl2html.py +++ b/atroposlib/frontend/jsonl2html.py @@ -117,6 +117,117 @@ def create_html_for_group(group_data, index): return group_html +# --- Eval Sample Conversion --- + + +def _eval_sample_to_viewable(sample): + """Convert a single eval sample dict to {messages, scores} format for the HTML viewer.""" + # Extract score + score = sample.get("score") + if score is None: + is_correct = sample.get("is_correct") + if is_correct is not None: + score = 1.0 if is_correct else 0.0 + else: + grade = sample.get("grade", "") + score = 1.0 if grade == "CORRECT" else 0.0 + + # Build conversation from available fields + if "messages" in sample and isinstance(sample["messages"], list): + conversation = sample["messages"] + else: + conversation = [] + question = sample.get("question") or sample.get("problem") or "" + if question: + conversation.append({"role": "user", "content": str(question)}) + + response = sample.get("model_response") or sample.get("response") or "" + if response: + conversation.append({"role": "assistant", "content": str(response)}) + + gold = sample.get("gold_answer") or sample.get("answer") or "" + if gold: + conversation.append({"role": "system", "content": f"[Gold Answer]: {gold}"}) + + if not conversation: + return None + + return {"messages": [conversation], "scores": [score]} + + +def generate_eval_html(input_path, output_path=None): + """Generate an HTML viewer from eval-format samples.jsonl. + + Each line is a flat dict with task-specific fields (question, model_response, score, etc.). + Converts them to the {messages, scores} format used by the existing HTML template. 
+ """ + input_filepath = Path(input_path) + if not input_filepath.is_file(): + print(f"Error: Input file not found: {input_filepath}", file=sys.stderr) + return + + if output_path is None: + output_filepath = input_filepath.with_suffix(".html") + else: + output_filepath = Path(output_path) + + output_filepath.parent.mkdir(parents=True, exist_ok=True) + + try: + with open(TEMPLATE_FILE, "r", encoding="utf-8") as f_template: + html_template_content = f_template.read() + except FileNotFoundError: + print(f"Error: Template file not found: {TEMPLATE_FILE}", file=sys.stderr) + return + + all_groups_html = [] + group_index = 0 + with open(input_filepath, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + sample = json.loads(line) + except json.JSONDecodeError: + print( + f"Warning: Skipping line {line_num}. Invalid JSON.", + file=sys.stderr, + ) + continue + + viewable = _eval_sample_to_viewable(sample) + if viewable is None: + continue + + group_html = create_html_for_group(viewable, group_index) + if group_html: + all_groups_html.append(group_html) + group_index += 1 + + if not all_groups_html: + print("Warning: No valid eval samples to render.", file=sys.stderr) + groups_content = "
<p class="no-data">No data to display.</p>
" + else: + groups_content = "\n".join(all_groups_html) + + title = f"Eval Results - {input_filepath.name}" + try: + final_html = html_template_content.format( + title=html.escape(title), groups_html=groups_content + ) + except KeyError as e: + print( + f"Error: Template missing placeholder: {{{e}}}", + file=sys.stderr, + ) + return + + with open(output_filepath, "w", encoding="utf-8") as f: + f.write(final_html) + print(f"Generated eval HTML viewer: {output_filepath.absolute()}") + + # --- Main Function --- diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 87823526e..09eca7c5b 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -189,17 +189,18 @@ async def rollout_and_score_eval(self, question: str, answer: str) -> dict: async def evaluate(self, *args, **kwargs): start_time = time.time() - eval_tasks = [] - for item in self.test: - eval_tasks.append( - self.rollout_and_score_eval(item["question"], item["gold_answer"]) + async def rollout_and_log(item): + result = await self.rollout_and_score_eval( + item["question"], item["gold_answer"] ) + if result is not None: + self.log_eval_sample(result.get("sample", result)) + return result + + eval_tasks = [rollout_and_log(item) for item in self.test] results = await tqdm_asyncio.gather(*eval_tasks) - # Extract scores and samples scores = [result["score"] for result in results] - samples = [result["sample"] for result in results] - percent_correct = sum(scores) / len(scores) end_time = time.time() @@ -207,14 +208,8 @@ async def evaluate(self, *args, **kwargs): # Add to existing metrics for wandb self.eval_metrics.append(("eval/percent_correct", percent_correct)) - # Log evaluation results - eval_metrics = { - "eval/percent_correct": percent_correct, - } - await self.evaluate_log( - metrics=eval_metrics, - samples=samples, + metrics={"eval/percent_correct": percent_correct}, start_time=start_time, end_time=end_time, generation_parameters={