Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/oumi/analyze/testing/batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,11 @@ def _build_final_result(self, acc: _TestAccumulator) -> TestResult:
affected_count=len(affected_ids),
total_count=acc.total_count,
affected_percentage=round(affected_pct, 2),
threshold=test.max_percentage or test.min_percentage,
threshold=(
test.max_percentage
if test.max_percentage is not None
else test.min_percentage
),
actual_value=None,
sample_indices=[], # Not meaningful for batch mode
all_affected_indices=[], # Not meaningful for batch mode
Expand Down
54 changes: 33 additions & 21 deletions src/oumi/analyze/testing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def _build_test_result(
affected_count=len(affected_indices),
total_count=total_count,
affected_percentage=round(affected_pct, 2),
threshold=test.max_percentage or test.min_percentage,
threshold=(
test.max_percentage
if test.max_percentage is not None
else test.min_percentage
),
actual_value=actual_value,
sample_indices=affected_indices[:MAX_SAMPLE_INDICES],
all_affected_indices=affected_indices,
Expand Down Expand Up @@ -179,24 +183,30 @@ def _run_single_test(
if not test.metric:
return self._create_error_result(test, "Test requires 'metric' field")

values = self._extract_metric_values(test.metric, results)
indexed_values = self._extract_metric_values(test.metric, results)

if not values:
if not indexed_values:
return self._create_error_result(
test, f"Metric '{test.metric}' not found in results"
)

if test.type == TestType.THRESHOLD:
return self._run_threshold_test(test, values)
return self._run_threshold_test(test, indexed_values)
else:
return self._create_error_result(test, f"Unknown test type: {test.type}")

def _extract_metric_values(
self,
metric: str,
results: dict[str, list[BaseModel] | BaseModel],
) -> list[Any]:
"""Extract values for a metric path like "instance_id.field_name"."""
) -> list[tuple[int, Any]]:
"""Extract ``(original_index, value)`` pairs for a metric path.

Preserving the original sample index ensures that ``sample_indices``
and ``all_affected_indices`` in the resulting ``TestResult`` map back
to actual conversation positions, even if some samples are missing
the metric (and therefore filtered out).
"""
parts = metric.split(".")
if len(parts) < 2:
return []
Expand All @@ -211,15 +221,15 @@ def _extract_metric_values(

if isinstance(analyzer_results, BaseModel):
value = self._get_nested_value(analyzer_results, field_path)
return [value] if value is not None else []
return [(0, value)] if value is not None else []

values = []
for result in analyzer_results:
indexed_values: list[tuple[int, Any]] = []
for i, result in enumerate(analyzer_results):
value = self._get_nested_value(result, field_path)
if value is not None:
values.append(value)
indexed_values.append((i, value))

return values
return indexed_values

def _get_nested_value(self, obj: Any, field_path: list[str]) -> Any:
"""Get a nested field value from a Pydantic model or dict."""
Expand Down Expand Up @@ -259,7 +269,7 @@ def _traverse_dict(self, d: dict, path: list[str]) -> Any | None:
def _run_threshold_test(
self,
test: TestParams,
values: list[Any],
indexed_values: list[tuple[int, Any]],
) -> TestResult:
"""Run a threshold test against metric values."""
if test.operator is None or test.value is None:
Expand All @@ -276,25 +286,25 @@ def _run_threshold_test(
matching_reasons: dict[int, str] = {}
non_matching_reasons: dict[int, str] = {}

for i, value in enumerate(values):
for orig_idx, value in indexed_values:
try:
if op_func(value, test.value):
matching_indices.append(i)
matching_reasons[i] = (
matching_indices.append(orig_idx)
matching_reasons[orig_idx] = (
f"Flagged: {test.metric} {test.operator} {test.value}"
f" (value={value})"
)
else:
non_matching_indices.append(i)
non_matching_reasons[i] = (
non_matching_indices.append(orig_idx)
non_matching_reasons[orig_idx] = (
f"Not flagged: {test.metric} {test.operator} {test.value}"
f" (value={value})"
)
except (TypeError, ValueError):
non_matching_indices.append(i)
non_matching_reasons[i] = f"Cannot evaluate: {value}"
non_matching_indices.append(orig_idx)
non_matching_reasons[orig_idx] = f"Cannot evaluate: {value}"

total_count = len(values)
total_count = len(indexed_values)
matching_count = len(matching_indices)
if total_count > 0:
matching_pct = 100.0 * matching_count / total_count
Expand Down Expand Up @@ -330,13 +340,15 @@ def _run_threshold_test(
affected_pct = matching_pct
failure_reasons = matching_reasons

raw_values = [v for _, v in indexed_values]

return self._build_test_result(
test=test,
passed=passed,
total_count=total_count,
affected_indices=affected_indices,
affected_pct=affected_pct,
actual_value=self._get_actual_value(values),
actual_value=self._get_actual_value(raw_values),
details={
"operator": test.operator,
"value": test.value,
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/analyze/test_testing_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,55 @@ def test_multi_instance_metrics_resolve_correctly():

assert summary.total_tests == 2
assert summary.passed_tests == 2


class _PartialMetrics(BaseModel):
    """Minimal model with one optional ``value`` field.

    Used by the index-tracking tests to simulate per-sample metrics where
    some samples are missing the metric entirely (``value=None``).
    """

    # None marks a sample that lacks the metric and is filtered out
    # during extraction.
    value: int | None = None


def test_sample_indices_map_to_original_positions_when_values_are_missing():
    """Indices in the result must refer to original sample positions.

    Sample 1 has no metric value and is dropped during extraction; the
    flagged sample must still be reported at its original position (2),
    not at its position in the filtered list (1).
    """
    metrics = [
        _PartialMetrics(value=1),
        _PartialMetrics(value=None),  # index 1 is skipped during extraction
        _PartialMetrics(value=999),  # the flagged one; must be reported as 2
        _PartialMetrics(value=2),
    ]
    results: dict[str, list[BaseModel] | BaseModel] = {"m": metrics}

    params = TestParams(
        id="flag_high",
        type=TestType.THRESHOLD,
        metric="m.value",
        operator=">",
        value=100,
    )
    summary = TestEngine([params]).run(results)

    outcome = summary.results[0]
    assert outcome.passed is False
    assert outcome.sample_indices == [2]
    assert outcome.all_affected_indices == [2]


def test_threshold_with_max_percentage_zero_sets_threshold_field():
    """``max_percentage=0.0`` must be reported as ``threshold=0.0``.

    Guards against the ``a or b`` truthiness bug, where a falsy 0.0 would
    incorrectly fall through to ``min_percentage``.
    """
    results: dict[str, list[BaseModel] | BaseModel] = {
        "m": [_PartialMetrics(value=1)]
    }
    params = TestParams(
        id="zero_pct",
        type=TestType.THRESHOLD,
        metric="m.value",
        operator=">",
        value=100,
        max_percentage=0.0,
    )

    summary = TestEngine([params]).run(results)

    assert summary.results[0].threshold == 0.0
Loading