Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,58 @@ def _infer_array_cast_type(line: str) -> str | None:
return None


def _erase_method_type_params(return_type: str, method_node: Any) -> str:
"""Erase method-level type parameters from the return type.

Generic methods like ``<T extends Comparable<T>> List<T> mergeSorted(...)``
declare type parameters that only exist within the method scope. When the
return type (e.g. ``List<T>``) is used as a cast in an instrumented test
class, those type variables are not in scope and cause compilation errors.

This function detects method-level type parameters via the tree-sitter AST
and replaces any occurrences in the return type with the wildcard ``?``.
If the return type *itself* is a bare type variable (e.g. ``T``), the type
is erased to ``Object``.
"""
# Find method-level type_parameters node via tree-sitter AST
ts_node = getattr(method_node, "node", None)
if ts_node is None:
return return_type

type_params_node = None
for child in ts_node.children:
if child.type == "type_parameters":
type_params_node = child
break

if type_params_node is None:
return return_type

# Collect declared type variable names (e.g. T, E, K, V)
type_var_names: set[str] = set()
for child in type_params_node.children:
if child.type == "type_parameter":
name_node = child.child_by_field_name("name") or (child.children[0] if child.children else None)
if name_node:
type_var_names.add(
name_node.text.decode("utf8") if isinstance(name_node.text, bytes) else str(name_node.text)
)

if not type_var_names:
return return_type

# If the entire return type is a bare type variable, erase to Object
if return_type.strip() in type_var_names:
return "Object"

# Replace type variables used as generic arguments with '?'
# Match whole-word type variable names that appear as generic type arguments
for tv in type_var_names:
return_type = re.sub(rf"\b{re.escape(tv)}\b", "?", return_type)

return return_type


def _extract_return_type(function_to_optimize: Any) -> str:
"""Extract the return type of a Java function from its source file using tree-sitter."""
file_path = getattr(function_to_optimize, "file_path", None)
Expand All @@ -522,7 +574,7 @@ def _extract_return_type(function_to_optimize: Any) -> str:
methods = analyzer.find_methods(source_text)
for method in methods:
if method.name == func_name and method.return_type:
return method.return_type
return _erase_method_type_params(method.return_type, method)
except Exception:
logger.debug("Could not extract return type for %s", func_name)
return ""
Expand Down
Loading
Loading