Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,58 @@ def _infer_array_cast_type(line: str) -> str | None:
return None


def _erase_method_type_params(return_type: str, method_node: Any) -> str:
"""Erase method-level type parameters from the return type.

Generic methods like ``<T extends Comparable<T>> List<T> mergeSorted(...)``
declare type parameters that only exist within the method scope. When the
return type (e.g. ``List<T>``) is used as a cast in an instrumented test
class, those type variables are not in scope and cause compilation errors.

This function detects method-level type parameters via the tree-sitter AST
and replaces any occurrences in the return type with the wildcard ``?``.
If the return type *itself* is a bare type variable (e.g. ``T``), the type
is erased to ``Object``.
"""
# Find method-level type_parameters node via tree-sitter AST
ts_node = getattr(method_node, "node", None)
if ts_node is None:
return return_type

type_params_node = None
for child in ts_node.children:
if child.type == "type_parameters":
type_params_node = child
break

if type_params_node is None:
return return_type

# Collect declared type variable names (e.g. T, E, K, V)
type_var_names: set[str] = set()
for child in type_params_node.children:
if child.type == "type_parameter":
name_node = child.child_by_field_name("name") or (child.children[0] if child.children else None)
if name_node:
type_var_names.add(
name_node.text.decode("utf8") if isinstance(name_node.text, bytes) else str(name_node.text)
)

if not type_var_names:
return return_type

# If the entire return type is a bare type variable, erase to Object
if return_type.strip() in type_var_names:
return "Object"

# Replace type variables used as generic arguments with '?'
# Match whole-word type variable names that appear as generic type arguments
for tv in type_var_names:
return_type = re.sub(rf"\b{re.escape(tv)}\b", "?", return_type)

return return_type


def _extract_return_type(function_to_optimize: Any) -> str:
"""Extract the return type of a Java function from its source file using tree-sitter."""
file_path = getattr(function_to_optimize, "file_path", None)
Expand All @@ -522,7 +574,7 @@ def _extract_return_type(function_to_optimize: Any) -> str:
methods = analyzer.find_methods(source_text)
for method in methods:
if method.name == func_name and method.return_type:
return method.return_type
return _erase_method_type_params(method.return_type, method)
except Exception:
logger.debug("Could not extract return type for %s", func_name)
return ""
Expand Down
113 changes: 113 additions & 0 deletions tests/test_languages/test_java/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from codeflash.languages.java.instrumentation import (
_add_behavior_instrumentation,
_add_timing_instrumentation,
_erase_method_type_params,
_extract_return_type,
create_benchmark_test,
instrument_existing_test,
instrument_for_behavior,
Expand Down Expand Up @@ -3485,3 +3487,114 @@ def __init__(self, path):
assert math.isclose(duration, 100_000_000, rel_tol=0.15), (
f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)"
)


class TestEraseMethodTypeParams:
"""Tests for _erase_method_type_params — erasing method-level type variables from return types."""

def test_generic_return_type_list(self):
"""Generic method <T> List<T> should have T erased to ? in return type."""
source = """public class CollectionUtils {
public static <T extends Comparable<T>> List<T> mergeSorted(List<T> a, List<T> b) {
return null;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer

analyzer = get_java_analyzer()
methods = analyzer.find_methods(source)
assert len(methods) == 1
result = _erase_method_type_params(methods[0].return_type, methods[0])
assert result == "List<?>", f"Expected 'List<?>' but got '{result}'"

def test_bare_type_variable_erased_to_object(self):
"""Generic method <T> T max(...) should erase bare T to Object."""
source = """public class Utils {
public static <T extends Comparable<T>> T max(T a, T b) {
return a.compareTo(b) >= 0 ? a : b;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer

analyzer = get_java_analyzer()
methods = analyzer.find_methods(source)
assert len(methods) == 1
result = _erase_method_type_params(methods[0].return_type, methods[0])
assert result == "Object", f"Expected 'Object' but got '{result}'"

def test_multiple_type_params(self):
"""Generic method <K, V> Map<K, V> should erase both K and V."""
source = """public class Utils {
public static <K, V> Map<K, V> combine(Map<K, V> a, Map<K, V> b) {
return null;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer

analyzer = get_java_analyzer()
methods = analyzer.find_methods(source)
assert len(methods) == 1
result = _erase_method_type_params(methods[0].return_type, methods[0])
assert result == "Map<?, ?>", f"Expected 'Map<?, ?>' but got '{result}'"

def test_non_generic_method_unchanged(self):
"""Non-generic method return type should be unchanged."""
source = """public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer

analyzer = get_java_analyzer()
methods = analyzer.find_methods(source)
assert len(methods) == 1
result = _erase_method_type_params("int", methods[0])
assert result == "int", f"Expected 'int' but got '{result}'"

def test_class_level_generics_not_erased(self):
"""Class-level type params should NOT be erased (only method-level ones)."""
source = """public class Box<T> {
public T getValue() {
return null;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer

analyzer = get_java_analyzer()
methods = analyzer.find_methods(source)
assert len(methods) == 1
# T is a class-level param, not method-level — should not be erased
result = _erase_method_type_params("T", methods[0])
assert result == "T", f"Expected 'T' (class-level generic unchanged) but got '{result}'"


class TestExtractReturnTypeGeneric:
"""Test that _extract_return_type erases method-level type params."""

def test_extract_return_type_generic_method(self, tmp_path):
"""_extract_return_type should return erased type for generic methods."""
java_file = tmp_path / "CollectionUtils.java"
java_file.write_text("""package com.example;
import java.util.List;

public class CollectionUtils {
public static <T extends Comparable<T>> List<T> mergeSorted(List<T> a, List<T> b) {
return null;
}
}
""")

class FakeFunc:
file_path = java_file
function_name = "mergeSorted"
qualified_name = "CollectionUtils.mergeSorted"
parents = []

result = _extract_return_type(FakeFunc())
assert result == "List<?>", f"Expected 'List<?>' but got '{result}'"
Loading