Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/oumi/analyze/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@

def to_analysis_dataframe(
conversations: list[Conversation],
results: Mapping[str, Sequence[BaseModel] | BaseModel],
results: Mapping[
str,
Sequence[BaseModel | dict[str, Any]] | BaseModel | dict[str, Any],
],
message_to_conversation_idx: list[int] | None = None,
) -> pd.DataFrame:
"""Convert typed analysis results to a pandas DataFrame.
Expand Down Expand Up @@ -117,7 +120,9 @@ def to_analysis_dataframe(
"missing values.",
)

elif isinstance(analyzer_results, BaseModel):
elif isinstance(analyzer_results, (BaseModel, dict)):
# Dataset-level result — same for every conversation. May be a
# raw dict when loaded from the pipeline's JSON cache.
_add_result_to_row(row, analyzer_results, prefix)

rows.append(row)
Expand Down Expand Up @@ -211,7 +216,10 @@ def _add_result_to_row(


def results_to_dict(
results: Mapping[str, Sequence[BaseModel] | BaseModel],
results: Mapping[
str,
Sequence[BaseModel | dict[str, Any]] | BaseModel | dict[str, Any],
],
) -> dict[str, list[dict[str, Any]] | dict[str, Any]]:
"""Convert typed results to a serializable dictionary.

Expand All @@ -227,8 +235,12 @@ def results_to_dict(

for name, result in results.items():
if isinstance(result, list):
output[name] = [r.model_dump() for r in result]
output[name] = [
r.model_dump() if isinstance(r, BaseModel) else r for r in result
]
elif isinstance(result, BaseModel):
output[name] = result.model_dump()
elif isinstance(result, dict):
output[name] = result

return output
49 changes: 49 additions & 0 deletions tests/unit/analyze/test_dataframe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,52 @@ def test_results_to_dict_empty():
output = results_to_dict({})

assert output == {}


# -----------------------------------------------------------------------------
# Tests: Cached (raw dict) results
# -----------------------------------------------------------------------------


def test_dataset_level_cached_dict_in_dataframe(sample_conversations):
    """Dataset-level results as raw dicts (from cache) are included."""
    cached_stats = {"total_count": 100, "avg_score": 75.5}

    df = to_analysis_dataframe(sample_conversations, {"Stats": cached_stats})

    # Each dataset-level field becomes a prefixed column, repeated per row.
    for column, expected in [
        ("stats__total_count", 100),
        ("stats__avg_score", 75.5),
    ]:
        assert column in df.columns
        assert (df[column] == expected).all()


def test_per_conversation_cached_dicts_in_dataframe(sample_conversations):
    """Per-conversation results as raw dicts (from cache) work."""
    per_conversation = [
        {"score": 80, "name": "a"},
        {"score": 90, "name": "b"},
    ]

    df = to_analysis_dataframe(sample_conversations, {"Simple": per_conversation})

    # One row per conversation, with prefixed columns in input order.
    assert {"simple__score", "simple__name"} <= set(df.columns)
    assert list(df["simple__score"]) == [80, 90]
    assert list(df["simple__name"]) == ["a", "b"]


def test_results_to_dict_passes_through_cached_dicts():
    """Raw dict results pass through without model_dump."""
    dataset_level = {"total_count": 100}
    per_conversation = [{"score": 1}, {"score": 2}]

    output = results_to_dict(
        {"Stats": dataset_level, "PerConv": per_conversation}
    )

    # Already-serialized dicts come back unchanged.
    assert output["Stats"] == {"total_count": 100}
    assert output["PerConv"] == [{"score": 1}, {"score": 2}]
Loading