diff --git a/src/oumi/analyze/utils/dataframe.py b/src/oumi/analyze/utils/dataframe.py index 26ce8f0c18..66cac3b987 100644 --- a/src/oumi/analyze/utils/dataframe.py +++ b/src/oumi/analyze/utils/dataframe.py @@ -28,7 +28,10 @@ def to_analysis_dataframe( conversations: list[Conversation], - results: Mapping[str, Sequence[BaseModel] | BaseModel], + results: Mapping[ + str, + Sequence[BaseModel | dict[str, Any]] | BaseModel | dict[str, Any], + ], message_to_conversation_idx: list[int] | None = None, ) -> pd.DataFrame: """Convert typed analysis results to a pandas DataFrame. @@ -117,7 +120,9 @@ def to_analysis_dataframe( "missing values.", ) - elif isinstance(analyzer_results, BaseModel): + elif isinstance(analyzer_results, (BaseModel, dict)): + # Dataset-level result — same for every conversation. May be a + # raw dict when loaded from the pipeline's JSON cache. _add_result_to_row(row, analyzer_results, prefix) rows.append(row) @@ -211,7 +216,10 @@ def _add_result_to_row( def results_to_dict( - results: Mapping[str, Sequence[BaseModel] | BaseModel], + results: Mapping[ + str, + Sequence[BaseModel | dict[str, Any]] | BaseModel | dict[str, Any], + ], ) -> dict[str, list[dict[str, Any]] | dict[str, Any]]: """Convert typed results to a serializable dictionary. @@ -227,8 +235,12 @@ def results_to_dict( for name, result in results.items(): if isinstance(result, list): - output[name] = [r.model_dump() for r in result] + output[name] = [ + r.model_dump() if isinstance(r, BaseModel) else r for r in result + ] elif isinstance(result, BaseModel): output[name] = result.model_dump() + elif isinstance(result, dict): + output[name] = result return output diff --git a/tests/unit/analyze/test_dataframe_utils.py b/tests/unit/analyze/test_dataframe_utils.py index 3263a4d1d5..c05dbd0299 100644 --- a/tests/unit/analyze/test_dataframe_utils.py +++ b/tests/unit/analyze/test_dataframe_utils.py @@ -417,3 +417,52 @@ def test_results_to_dict_empty(): output = results_to_dict({}) assert output == {} + + +# ----------------------------------------------------------------------------- +# Tests: Cached (raw dict) results +# ----------------------------------------------------------------------------- + + +def test_dataset_level_cached_dict_in_dataframe(sample_conversations): + """Dataset-level results as raw dicts (from cache) are included.""" + results = { + "Stats": {"total_count": 100, "avg_score": 75.5}, + } + + df = to_analysis_dataframe(sample_conversations, results) + + assert "stats__total_count" in df.columns + assert "stats__avg_score" in df.columns + assert (df["stats__total_count"] == 100).all() + assert (df["stats__avg_score"] == 75.5).all() + + +def test_per_conversation_cached_dicts_in_dataframe(sample_conversations): + """Per-conversation results as raw dicts (from cache) work.""" + results = { + "Simple": [ + {"score": 80, "name": "a"}, + {"score": 90, "name": "b"}, + ], + } + + df = to_analysis_dataframe(sample_conversations, results) + + assert "simple__score" in df.columns + assert "simple__name" in df.columns + assert df["simple__score"].tolist() == [80, 90] + assert df["simple__name"].tolist() == ["a", "b"] + + +def test_results_to_dict_passes_through_cached_dicts(): + """Raw dict results pass through without model_dump.""" + results = { + "Stats": {"total_count": 100}, + "PerConv": [{"score": 1}, {"score": 2}], + } + + output = results_to_dict(results) + + assert output["Stats"] == {"total_count": 100} + assert output["PerConv"] == [{"score": 1}, {"score": 2}]