def _erase_method_type_params(return_type: str, method_node: Any) -> str:
    """Erase method-level type parameters from a Java return type.

    Generic methods such as ``<T extends Comparable<T>> List<T> mergeSorted(...)``
    declare type variables that only exist within the method's own scope. When
    the return type is reused as a cast inside an instrumented test class,
    those variables are not in scope and the generated code does not compile.

    The method-level ``type_parameters`` node is located on the tree-sitter
    node exposed as ``method_node.node`` and every declared type variable is
    erased according to where it appears in ``return_type``:

    * the whole return type is a type variable (``T``) -> ``Object``
    * a type variable outside any ``<...>`` (``T[]``, ``@Ann T``) -> ``Object``,
      keeping array dimensions (``Object[]``), because ``?`` is only legal as a
      type argument
    * a type variable used as a type argument (``List<T>``, ``List<T[]>``) -> ``?``
    * a bounded wildcard over a type variable (``List<? extends T>``) -> ``?``,
      because ``? extends ?`` is not valid Java

    Args:
        return_type: Return type text as written in the Java source.
        method_node: Object whose ``node`` attribute is the tree-sitter node of
            the method declaration. Any object without that attribute is
            tolerated.

    Returns:
        ``return_type`` with method-level type variables erased, or the input
        unchanged when no tree-sitter node or no type parameters are found.
    """
    ts_node = getattr(method_node, "node", None)
    if ts_node is None:
        return return_type

    # Only the method's own <...> declaration matters; class-level type
    # parameters are not children of the method node.
    type_params_node = next((child for child in ts_node.children if child.type == "type_parameters"), None)
    if type_params_node is None:
        return return_type

    # Collect declared type variable names (e.g. T, E, K, V).
    type_var_names: set[str] = set()
    for param in type_params_node.children:
        if param.type != "type_parameter":
            continue
        name_node = param.child_by_field_name("name")
        if name_node is None:
            # Prefer the identifier child: with an annotated parameter such as
            # ``<@NonNull T>`` the first child is the annotation, not the name.
            name_node = next(
                (kid for kid in param.children if kid.type in ("type_identifier", "identifier")),
                param.children[0] if param.children else None,
            )
        if name_node is None:
            continue
        raw_name = name_node.text
        type_var_names.add(raw_name.decode("utf8") if isinstance(raw_name, bytes) else str(raw_name))

    if not type_var_names:
        return return_type

    # If the entire return type is a bare type variable, erase to Object.
    if return_type.strip() in type_var_names:
        return "Object"

    alternatives = "|".join(re.escape(name) for name in sorted(type_var_names, key=len, reverse=True))
    # Whole-identifier match: not preceded by an identifier character or a dot
    # (so members of qualified names are left alone), not followed by an
    # identifier character. Optionally swallows a leading ``? extends`` /
    # ``? super`` and any trailing array dimensions.
    type_var_pattern = re.compile(
        rf"(?P<bound>\?\s*(?:extends|super)\s+)?(?<![\w.$])(?:{alternatives})(?![\w$])(?P<dims>(?:\s*\[\s*\])*)"
    )

    def _erase(match: re.Match) -> str:
        if match.group("bound"):
            # ``? extends T`` / ``? super T`` collapses to an unbounded wildcard.
            return "?"
        # Nesting depth of <...> at the match position in the original text.
        depth = match.string.count("<", 0, match.start()) - match.string.count(">", 0, match.start())
        if depth > 0:
            # ``?[]`` is not valid Java, so array dimensions are dropped here.
            return "?"
        # Wildcards are illegal outside type arguments; fall back to Object.
        return "Object" + match.group("dims")

    return type_var_pattern.sub(_erase, return_type)
MavenStrategy().find_executable(Path()) is None, reason="Maven not found - skipping execution tests" ) @@ -3485,3 +3485,337 @@ def __init__(self, path): assert math.isclose(duration, 100_000_000, rel_tol=0.15), ( f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)" ) + + +class TestGenericMethodTypeErasureInstrumentation: + """Tests that generic method type parameters are erased in instrumented output.""" + + def test_generic_list_return_type_erased_in_behavior_cast(self, tmp_path): + """Generic List return type should produce (List)cast in behavior mode.""" + src_file = (tmp_path / "CollectionUtils.java").resolve() + src_file.write_text( + """package com.example; +import java.util.List; + +public class CollectionUtils { + public static > List mergeSorted(List a, List b) { + return null; + } +} +""", + encoding="utf-8", + ) + + test_source = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +public class CollectionUtilsTest { + @Test + public void testMergeSorted() { + assertEquals(Arrays.asList(1, 2, 3, 4), CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4))); + } +} +""" + test_file = (tmp_path / "CollectionUtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="mergeSorted", + file_path=src_file, + starting_line=5, + ending_line=7, + parents=[], + is_method=False, + language="java", + ) + + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file + ) + assert success + + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; +import java.sql.Connection; +import java.sql.DriverManager; +import 
java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class CollectionUtilsTest__perfinstrumented { + @Test + public void testMergeSorted() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "CollectionUtilsTest__perfinstrumented"; + String _cf_cls1 = "CollectionUtilsTest__perfinstrumented"; + String _cf_fn1 = "mergeSorted"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testMergeSorted"; + Object _cf_result1_1 = null; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L15_1" + "######$!"); + try { + _cf_start1_1 = System.nanoTime(); + _cf_result1_1 = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); + } finally { + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." 
+ _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L15_1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, "L15_1"); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + assertEquals(Arrays.asList(1, 2, 3, 4), (List)_cf_result1_1); + } +} +""" + assert instrumented == expected_instrumented + + def test_bare_type_variable_erased_to_object_in_behavior_cast(self, tmp_path): + """Generic T return type should produce (Object)cast in behavior mode.""" + src_file = (tmp_path / "Utils.java").resolve() + src_file.write_text( + """package com.example; + +public class Utils { + public static > T max(T a, T b) { + return a.compareTo(b) >= 0 ? 
a : b; + } +} +""", + encoding="utf-8", + ) + + test_source = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class UtilsTest { + @Test + public void testMax() { + assertEquals(5, Utils.max(3, 5)); + } +} +""" + test_file = (tmp_path / "UtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="max", + file_path=src_file, + starting_line=4, + ending_line=6, + parents=[], + is_method=False, + language="java", + ) + + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file + ) + assert success + + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class UtilsTest__perfinstrumented { + @Test + public void testMax() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "UtilsTest__perfinstrumented"; + String _cf_cls1 = "UtilsTest__perfinstrumented"; + String _cf_fn1 = "max"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testMax"; + Object _cf_result1_1 = null; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." 
+ _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L13_1" + "######$!"); + try { + _cf_start1_1 = System.nanoTime(); + _cf_result1_1 = Utils.max(3, 5); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); + } finally { + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L13_1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, "L13_1"); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + assertEquals(5, (Object)_cf_result1_1); + } +} +""" + assert instrumented == expected_instrumented + + def 
test_generic_return_type_performance_mode(self, tmp_path): + """Generic method in performance mode should compile without type variable errors.""" + src_file = (tmp_path / "CollectionUtils.java").resolve() + src_file.write_text( + """package com.example; +import java.util.List; + +public class CollectionUtils { + public static > List mergeSorted(List a, List b) { + return null; + } +} +""", + encoding="utf-8", + ) + + test_source = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +public class CollectionUtilsTest { + @Test + public void testMergeSorted() { + List result = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + assertEquals(Arrays.asList(1, 2, 3, 4), result); + } +} +""" + test_file = (tmp_path / "CollectionUtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="mergeSorted", + file_path=src_file, + starting_line=5, + ending_line=7, + parents=[], + is_method=False, + language="java", + ) + + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file + ) + assert success + + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +@SuppressWarnings("CheckReturnValue") +public class CollectionUtilsTest__perfonlyinstrumented { + @Test + public void testMergeSorted() { + // Codeflash timing instrumentation with inner loop for JIT warmup + int _cf_outerLoop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_maxInnerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10")); + int _cf_innerIterations1 = 
Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10")); + String _cf_mod1 = "CollectionUtilsTest__perfonlyinstrumented"; + String _cf_cls1 = "CollectionUtilsTest__perfonlyinstrumented"; + String _cf_test1 = "testMergeSorted"; + String _cf_fn1 = "mergeSorted"; + + List result = null; + for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) { + int _cf_loopId1 = _cf_outerLoop1 * _cf_maxInnerIterations1 + _cf_i1; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L12_1" + "######$!"); + long _cf_end1 = -1; + long _cf_start1 = 0; + try { + _cf_start1 = System.nanoTime(); + result = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + _cf_end1 = System.nanoTime(); + } finally { + long _cf_end1_finally = System.nanoTime(); + long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L12_1" + ":" + _cf_dur1 + "######!"); + } + } + assertEquals(Arrays.asList(1, 2, 3, 4), result); + } +} +""" + assert instrumented == expected_instrumented