7 changes: 7 additions & 0 deletions src/google/adk/evaluation/eval_metrics.py
@@ -226,6 +226,13 @@ class MatchType(Enum):
),
)

ignore_args: bool = Field(
default=False,
description=(
"If True, only tool names are compared; arguments are ignored."
),
)

@field_validator("match_type", mode="before")
@classmethod
def _coerce_match_type(cls, value: object) -> object:
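
As context for the field added above: `ignore_args` reduces each per-call comparison to a name check. The helper below is a hypothetical illustration of that contract only, not part of the ADK API; `genai_types.FunctionCall` is the same type the tests in this PR construct.

from google.genai import types as genai_types


def calls_match(
    actual: genai_types.FunctionCall,
    expected: genai_types.FunctionCall,
    ignore_args: bool,
) -> bool:
  # Names must always agree; arguments are only checked when the flag is off.
  if actual.name != expected.name:
    return False
  return ignore_args or actual.args == expected.args


actual = genai_types.FunctionCall(name="get_weather", args={"city": "Paris"})
expected = genai_types.FunctionCall(name="get_weather", args={"city": "Tokyo"})
assert calls_match(actual, expected, ignore_args=True)  # names match
assert not calls_match(actual, expected, ignore_args=False)  # args differ
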
16 changes: 11 additions & 5 deletions src/google/adk/evaluation/trajectory_evaluator.py
@@ -82,6 +82,7 @@ def __init__(
)
self._threshold = criterion.threshold
self._match_type = criterion.match_type
self._ignore_args = criterion.ignore_args
except ValidationError as e:
expected_criterion_type_error = ValueError(
f"`{eval_metric.metric_name}` metric expects a criterion of type"
@@ -91,9 +92,11 @@
elif eval_metric:
self._threshold = eval_metric.threshold
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
self._ignore_args = False
else:
self._threshold = threshold
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
self._ignore_args = False

@override
def evaluate_invocations(
@@ -191,9 +194,8 @@ def _are_tool_calls_in_order_match(
try:
current_expected = next(expected_it)
for actual in actual_tool_calls:
-          if (
-              actual.name == current_expected.name
-              and actual.args == current_expected.args
+          if actual.name == current_expected.name and (
+              self._ignore_args or actual.args == current_expected.args
):
current_expected = next(expected_it)
except StopIteration:
@@ -229,7 +231,9 @@ def _are_tool_calls_any_order_match(
for expected in expected_tool_calls:
found = False
for i, actual in enumerate(actual_tool_calls_copy):
-        if actual.name == expected.name and actual.args == expected.args:
+        if actual.name == expected.name and (
+            self._ignore_args or actual.args == expected.args
+        ):
actual_tool_calls_copy.pop(i)
found = True
break
@@ -260,7 +264,9 @@ def _are_tool_calls_exact_match(
return False

for actual, expected in zip(actual_tool_calls, expected_tool_calls):
-      if actual.name != expected.name or actual.args != expected.args:
+      if actual.name != expected.name or (
+          not self._ignore_args and actual.args != expected.args
+      ):
return False

return True
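
Putting the evaluator changes together: `ignore_args` only relaxes the per-call argument comparison, while each match type keeps its structural guarantees (EXACT still fails on a count mismatch, IN_ORDER on a reordering, ANY_ORDER on a missing tool, as the tests below check). A usage sketch, with import paths inferred from the files touched in this PR:

from google.adk.evaluation.eval_metrics import (
    EvalMetric,
    PrebuiltMetrics,
    ToolTrajectoryCriterion,
)
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

# Match tool names in order, tolerating extra calls and differing arguments.
evaluator = TrajectoryEvaluator(
    eval_metric=EvalMetric(
        metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
        criterion=ToolTrajectoryCriterion(
            threshold=1.0,
            match_type=ToolTrajectoryCriterion.MatchType.IN_ORDER,
            ignore_args=True,
        ),
    )
)
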
327 changes: 327 additions & 0 deletions tests/unittests/evaluation/test_trajectory_evaluator.py
@@ -462,3 +462,330 @@ def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
assert result.overall_score is None
assert result.overall_eval_status == EvalStatus.NOT_EVALUATED
assert not result.per_invocation_results


# --- ignore_args tests ---


def _make_ignore_args_evaluator(
match_type: ToolTrajectoryCriterion.MatchType,
) -> TrajectoryEvaluator:
return TrajectoryEvaluator(
eval_metric=EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion=ToolTrajectoryCriterion(
threshold=0.5,
match_type=match_type,
ignore_args=True,
),
)
)


def test_exact_ignore_args_passes_with_different_args():
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={"a": 1}),
genai_types.FunctionCall(name="t2", args={"b": 2}),
]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0


def test_exact_ignore_args_fails_with_different_names():
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t2", args={})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_in_order_ignore_args_passes_with_different_args():
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.IN_ORDER)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={"a": 1}),
genai_types.FunctionCall(name="extra", args={}),
genai_types.FunctionCall(name="t2", args={"b": 2}),
]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0


def test_any_order_ignore_args_passes_with_different_args():
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.ANY_ORDER)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t2", args={"b": 2}),
genai_types.FunctionCall(name="t1", args={"a": 1}),
]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0


def test_any_order_ignore_args_fails_with_missing_tool():
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.ANY_ORDER)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_ignore_args_from_dict_config():
"""Tests that ignore_args works when passed as a dict criterion."""
eval_metric = EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion={
"threshold": 0.5,
"match_type": "ANY_ORDER",
"ignore_args": True,
},
)
ev = TrajectoryEvaluator(eval_metric=eval_metric)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 999})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0


def test_exact_ignore_args_fails_with_different_count():
"""EXACT + ignore_args still fails when tool call counts differ."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_in_order_ignore_args_fails_with_wrong_order():
"""IN_ORDER + ignore_args still fails when order is wrong."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.IN_ORDER)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t2", args={}),
genai_types.FunctionCall(name="t1", args={}),
]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_in_order_ignore_args_fails_with_missing_tool():
"""IN_ORDER + ignore_args still fails when expected tool is missing."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.IN_ORDER)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_ignore_args_false_still_checks_args():
"""Confirm ignore_args=False (default) still enforces arg matching."""
ev = TrajectoryEvaluator(
eval_metric=EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion=ToolTrajectoryCriterion(
threshold=0.5,
match_type=ToolTrajectoryCriterion.MatchType.EXACT,
ignore_args=False,
),
)
)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 2})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_ignore_args_empty_tool_lists():
"""ignore_args with empty tool lists on both sides should pass."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
inv = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(tool_uses=[]),
)
result = ev.evaluate_invocations([inv], [inv])
assert result.overall_score == 1.0


def test_ignore_args_multiple_invocations_mixed():
"""ignore_args with multiple invocations: one matches, one doesn't."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
inv1_actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
inv1_expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 99})]
),
)
inv2_actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
inv2_expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t2", args={})]
),
)
result = ev.evaluate_invocations(
[inv1_actual, inv2_actual], [inv1_expected, inv2_expected]
)
assert result.overall_score == 0.5
assert result.per_invocation_results[0].score == 1.0
assert result.per_invocation_results[1].score == 0.0


def test_ignore_args_with_camel_case_dict_config():
"""Tests ignore_args works via camelCase key (ignoreArgs) in dict."""
eval_metric = EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion={
"threshold": 0.5,
"matchType": "EXACT",
"ignoreArgs": True,
},
)
ev = TrajectoryEvaluator(eval_metric=eval_metric)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 999})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0