diff --git a/src/oumi/analyze/testing/batch_engine.py b/src/oumi/analyze/testing/batch_engine.py
index eafc723b13..7cb0050cd6 100644
--- a/src/oumi/analyze/testing/batch_engine.py
+++ b/src/oumi/analyze/testing/batch_engine.py
@@ -289,7 +289,11 @@ def _build_final_result(self, acc: _TestAccumulator) -> TestResult:
             affected_count=len(affected_ids),
             total_count=acc.total_count,
             affected_percentage=round(affected_pct, 2),
-            threshold=test.max_percentage or test.min_percentage,
+            threshold=(
+                test.max_percentage
+                if test.max_percentage is not None
+                else test.min_percentage
+            ),
             actual_value=None,
             sample_indices=[],  # Not meaningful for batch mode
             all_affected_indices=[],  # Not meaningful for batch mode
diff --git a/src/oumi/analyze/testing/engine.py b/src/oumi/analyze/testing/engine.py
index a3fc5629cd..64f5e0335e 100644
--- a/src/oumi/analyze/testing/engine.py
+++ b/src/oumi/analyze/testing/engine.py
@@ -113,7 +113,11 @@ def _build_test_result(
             affected_count=len(affected_indices),
             total_count=total_count,
             affected_percentage=round(affected_pct, 2),
-            threshold=test.max_percentage or test.min_percentage,
+            threshold=(
+                test.max_percentage
+                if test.max_percentage is not None
+                else test.min_percentage
+            ),
             actual_value=actual_value,
             sample_indices=affected_indices[:MAX_SAMPLE_INDICES],
             all_affected_indices=affected_indices,
@@ -179,15 +183,15 @@ def _run_single_test(
         if not test.metric:
             return self._create_error_result(test, "Test requires 'metric' field")
 
-        values = self._extract_metric_values(test.metric, results)
+        indexed_values = self._extract_metric_values(test.metric, results)
 
-        if not values:
+        if not indexed_values:
             return self._create_error_result(
                 test, f"Metric '{test.metric}' not found in results"
             )
 
         if test.type == TestType.THRESHOLD:
-            return self._run_threshold_test(test, values)
+            return self._run_threshold_test(test, indexed_values)
         else:
             return self._create_error_result(test, f"Unknown test type: {test.type}")
 
@@ -195,8 +199,14 @@ def _extract_metric_values(
         self,
         metric: str,
         results: dict[str, list[BaseModel] | BaseModel],
-    ) -> list[Any]:
-        """Extract values for a metric path like "instance_id.field_name"."""
+    ) -> list[tuple[int, Any]]:
+        """Extract ``(original_index, value)`` pairs for a metric path.
+
+        Preserving the original sample index ensures that ``sample_indices``
+        and ``all_affected_indices`` in the resulting ``TestResult`` map back
+        to actual conversation positions, even if some samples are missing
+        the metric (and therefore filtered out).
+        """
         parts = metric.split(".")
         if len(parts) < 2:
             return []
@@ -211,15 +221,15 @@ def _extract_metric_values(
 
         if isinstance(analyzer_results, BaseModel):
             value = self._get_nested_value(analyzer_results, field_path)
-            return [value] if value is not None else []
+            return [(0, value)] if value is not None else []
 
-        values = []
-        for result in analyzer_results:
+        indexed_values: list[tuple[int, Any]] = []
+        for i, result in enumerate(analyzer_results):
             value = self._get_nested_value(result, field_path)
             if value is not None:
-                values.append(value)
+                indexed_values.append((i, value))
 
-        return values
+        return indexed_values
 
     def _get_nested_value(self, obj: Any, field_path: list[str]) -> Any:
         """Get a nested field value from a Pydantic model or dict."""
@@ -259,7 +269,7 @@ def _traverse_dict(self, d: dict, path: list[str]) -> Any | None:
     def _run_threshold_test(
         self,
         test: TestParams,
-        values: list[Any],
+        indexed_values: list[tuple[int, Any]],
     ) -> TestResult:
         """Run a threshold test against metric values."""
         if test.operator is None or test.value is None:
@@ -276,25 +286,25 @@ def _run_threshold_test(
         matching_reasons: dict[int, str] = {}
         non_matching_reasons: dict[int, str] = {}
 
-        for i, value in enumerate(values):
+        for orig_idx, value in indexed_values:
             try:
                 if op_func(value, test.value):
-                    matching_indices.append(i)
-                    matching_reasons[i] = (
+                    matching_indices.append(orig_idx)
+                    matching_reasons[orig_idx] = (
                         f"Flagged: {test.metric} {test.operator} {test.value}"
                         f" (value={value})"
                     )
                 else:
-                    non_matching_indices.append(i)
-                    non_matching_reasons[i] = (
+                    non_matching_indices.append(orig_idx)
+                    non_matching_reasons[orig_idx] = (
                         f"Not flagged: {test.metric} {test.operator} {test.value}"
                         f" (value={value})"
                     )
             except (TypeError, ValueError):
-                non_matching_indices.append(i)
-                non_matching_reasons[i] = f"Cannot evaluate: {value}"
+                non_matching_indices.append(orig_idx)
+                non_matching_reasons[orig_idx] = f"Cannot evaluate: {value}"
 
-        total_count = len(values)
+        total_count = len(indexed_values)
         matching_count = len(matching_indices)
         if total_count > 0:
             matching_pct = 100.0 * matching_count / total_count
@@ -330,13 +340,15 @@ def _run_threshold_test(
             affected_pct = matching_pct
             failure_reasons = matching_reasons
 
+        raw_values = [v for _, v in indexed_values]
+
         return self._build_test_result(
             test=test,
             passed=passed,
             total_count=total_count,
             affected_indices=affected_indices,
             affected_pct=affected_pct,
-            actual_value=self._get_actual_value(values),
+            actual_value=self._get_actual_value(raw_values),
             details={
                 "operator": test.operator,
                 "value": test.value,
diff --git a/tests/unit/analyze/test_testing_engine.py b/tests/unit/analyze/test_testing_engine.py
index 1c228ae31b..74ffbde3fb 100644
--- a/tests/unit/analyze/test_testing_engine.py
+++ b/tests/unit/analyze/test_testing_engine.py
@@ -509,3 +509,55 @@ def test_multi_instance_metrics_resolve_correctly():
 
     assert summary.total_tests == 2
     assert summary.passed_tests == 2
+
+
+class _PartialMetrics(BaseModel):
+    """Metrics with an optional value for index-tracking tests."""
+
+    value: int | None = None
+
+
+def test_sample_indices_map_to_original_positions_when_values_are_missing():
+    """When some samples lack the metric, sample_indices must point to the
+    original conversation positions, not filtered-list positions."""
+    results: dict[str, list[BaseModel] | BaseModel] = {
+        "m": [
+            _PartialMetrics(value=1),
+            _PartialMetrics(value=None),  # index 1 is skipped during extraction
+            _PartialMetrics(value=999),  # the flagged one; must be reported as 2
+            _PartialMetrics(value=2),
+        ]
+    }
+    tests = [
+        TestParams(
+            id="flag_high",
+            type=TestType.THRESHOLD,
+            metric="m.value",
+            operator=">",
+            value=100,
+        )
+    ]
+    summary = TestEngine(tests).run(results)
+
+    result = summary.results[0]
+    assert result.passed is False
+    assert result.sample_indices == [2]
+    assert result.all_affected_indices == [2]
+
+
+def test_threshold_with_max_percentage_zero_sets_threshold_field():
+    """A max_percentage of 0 must surface as threshold=0.0, not fall through
+    to min_percentage (fixes the ``a or b`` truthiness bug)."""
+    results: dict[str, list[BaseModel] | BaseModel] = {"m": [_PartialMetrics(value=1)]}
+    tests = [
+        TestParams(
+            id="zero_pct",
+            type=TestType.THRESHOLD,
+            metric="m.value",
+            operator=">",
+            value=100,
+            max_percentage=0.0,
+        )
+    ]
+    summary = TestEngine(tests).run(results)
+    assert summary.results[0].threshold == 0.0