Skip to content

⚡️ Speed up method JavaAssertTransformer._generate_replacement by 25% in PR #1980 (cf-java-void-optimization)#1983

Merged
claude[bot] merged 1 commit intocf-java-void-optimizationfrom
codeflash/optimize-pr1980-2026-04-03T13.47.39
Apr 3, 2026
Merged

⚡️ Speed up method JavaAssertTransformer._generate_replacement by 25% in PR #1980 (cf-java-void-optimization)#1983
claude[bot] merged 1 commit intocf-java-void-optimizationfrom
codeflash/optimize-pr1980-2026-04-03T13.47.39

Conversation

@codeflash-ai
Copy link
Copy Markdown
Contributor

@codeflash-ai codeflash-ai bot commented Apr 3, 2026

⚡️ This pull request contains optimizations for PR #1980

If you approve this dependent PR, these changes will be merged into the original PR branch cf-java-void-optimization.

This PR will be automatically closed if the original PR is merged.


📄 25% (0.25x) speedup for JavaAssertTransformer._generate_replacement in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 886 microseconds 708 microseconds (best of 250 runs)

📝 Explanation and details

The optimization added a dictionary cache (_type_infer_cache) to memoize the results of _infer_type_from_assertion_args, which is an expensive method involving regex operations and string parsing. Before, every call to _infer_return_type for a value assertion (e.g., assertEquals) would re-parse the assertion's original text, spending ~92% of the method's runtime in _infer_type_from_assertion_args. With caching keyed on (original_text, method), repeated assertions with identical text reuse the inferred type, cutting _infer_return_type time by 59% (1.38 ms → 570 µs) and overall runtime by 20% (886 µs → 708 µs). No functional regressions observed across all test cases.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 281 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from collections import namedtuple

# imports
# Import the class under test from the real module
from codeflash.languages.java.remove_asserts import AssertionMatch, JavaAssertTransformer

# Helper lightweight real classes: using namedtuple creates real classes (not mocks)
# Call object used in assertions to represent a target call with a `.full_call` attribute.
Call = namedtuple("Call", ["full_call"])

# AssertionMatch-like structure (real class via namedtuple) that matches the attributes
# expected by JavaAssertTransformer._generate_replacement.
AssertionMatch = namedtuple(
    "AssertionMatch",
    [
        "is_exception_assertion",  # bool: whether this is an exception-style assertion
        "target_calls",  # list[Call] | None: the calls to preserve
        "leading_whitespace",  # str: whitespace to prepend to first replaced line
        "assertion_method",  # str: name of the assertion method (inference depends on it)
        "original_text",  # str: original assertion text (used for some inference)
        "lambda_body",  # str | None: body of a lambda for assertThrows style assertions
        "assigned_var_name",  # str | None: variable name when assertion is assigned
        "assigned_var_type",  # str | None: variable type when assertion is assigned
        "exception_class",  # str | None: explicit exception class provided to assertThrows
    ],
)


def test_strip_mode_single_call_returns_bare_call_with_semicolon():
    # Create a transformer in "strip" mode so it emits bare calls (no captures).
    t = JavaAssertTransformer(function_name="f", mode="strip")
    # Single target call; leading whitespace preserved.
    call = Call(full_call="calculator.divide(1, 0)")
    am = AssertionMatch(
        is_exception_assertion=False,
        target_calls=[call],
        leading_whitespace="    ",  # 4 spaces indentation
        assertion_method="assertEquals",
        original_text="assertEquals(1, calculator.divide(1,0));",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    # Expect a single line: leading whitespace + call + semicolon
    out = t._generate_replacement(am)  # 2.50μs -> 2.33μs (6.90% faster)
    assert out == "    calculator.divide(1, 0);", "Strip mode should emit bare call with same leading whitespace"


def test_strip_mode_multiple_calls_uses_base_indent_for_followups():
    # When multiple calls exist, first line keeps leading whitespace (including newline),
    # subsequent lines should only use the base indent (leading_ws with newlines stripped).
    t = JavaAssertTransformer(function_name="f", mode="strip")
    calls = [Call(full_call="first()"), Call(full_call="second()"), Call(full_call="third()")]
    # Simulate leading whitespace starting with a newline then 4 spaces.
    leading_ws = "\n    "
    am = AssertionMatch(
        is_exception_assertion=False,
        target_calls=calls,
        leading_whitespace=leading_ws,
        assertion_method="assertEquals",
        original_text="assertEquals(1, first());",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    out = t._generate_replacement(am)  # 2.95μs -> 2.88μs (2.47% faster)
    # Build expected string: first line keeps the newline, others use "    " only.
    expected = "\n    first();\n    second();\n    third();"
    assert out == expected, "Subsequent lines should use base_indent only (no extra leading newlines)"


def test_capture_mode_multiple_calls_generates_typed_results_and_updates_counter():
    # Use capture mode (default) and pick an assertion method that results in Object return type
    # to avoid complicated literal inference logic in the implementation.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    # Ensure initial counter is zero
    assert t.invocation_counter == 0
    calls = [Call(full_call="alpha()"), Call(full_call="beta()")]
    leading_ws = "\n    "
    am = AssertionMatch(
        is_exception_assertion=False,
        target_calls=calls,
        leading_whitespace=leading_ws,
        assertion_method="assertNotNull",  # method maps directly to Object without literal inference
        original_text="assertNotNull(obj);",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )  # 3.96μs -> 4.03μs (1.71% slower)

    out = t._generate_replacement(am)

    # Expect two lines, the first with leading_ws preserved (including newline), both lines should
    # declare variables named _cf_result1 and _cf_result2 with type "Object".
    expected = "\n    Object _cf_result1 = alpha();\n    Object _cf_result2 = beta();"
    assert out == expected, "Capture mode should produce typed temp variables capturing each call"

    # The transformer's invocation_counter should have been advanced to 2
    assert t.invocation_counter == 2, "invocation_counter should reflect number of generated result variables"


def test_exception_assertion_lambda_generates_try_catch_and_increments_counter():
    # assertThrows-style assertion with a lambda body (not assigned) should produce a try/catch.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    am = AssertionMatch(
        is_exception_assertion=True,
        target_calls=None,
        leading_whitespace="    ",
        assertion_method="assertThrows",
        original_text="assertThrows(Exception.class, () -> calculator.divide(1, 0));",
        lambda_body="calculator.divide(1, 0)",  # note: no trailing semicolon
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    out = t._generate_replacement(am)  # 2.38μs -> 2.29μs (3.88% faster)
    # The implementation appends a semicolon if missing, and uses counter starting at 1.
    expected = "    try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}"
    assert out == expected, "Exception assertion without assignment should emit try/catch ignoring exceptions"

    # Counter should have incremented by 1
    assert t.invocation_counter == 1, "invocation_counter increments once for exception assertions"


def test_exception_assertion_with_assignment_generates_caught_and_ignored_handlers():
    # When the exception assertion is assigned to a variable, the transformer should produce
    # a null initialization followed by try/catch that catches and assigns the caught exception.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    am = AssertionMatch(
        is_exception_assertion=True,
        target_calls=None,
        leading_whitespace="\t",  # tab indentation
        assertion_method="assertThrows",
        original_text="IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());",
        lambda_body="code()",  # no trailing semicolon -> the code_to_run will get a semicolon appended
        assigned_var_name="ex",
        assigned_var_type="IllegalArgumentException",
        exception_class="IllegalArgumentException",
    )

    out = t._generate_replacement(am)  # 3.19μs -> 3.19μs (0.000% faster)
    # Expected pattern:
    # "\tIllegalArgumentException ex = null;\n"
    # "\ttry { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}"
    base_indent = "\t"
    expected = (
        f"{base_indent}IllegalArgumentException ex = null;\n"
        f"{base_indent}try {{ code(); }} "
        f"catch (IllegalArgumentException _cf_caught1) {{ ex = _cf_caught1; }} "
        f"catch (Exception _cf_ignored1) {{}}"
    )
    assert out == expected, (
        "Assigned assertThrows should initialize var to null and catch & assign the caught exception"
    )


def test_assert_does_not_throw_with_simple_expression_assigned_emits_assignment_not_try():
    # Special-case: assertDoesNotThrow assigned to a variable with a simple lambda body that is
    # an expression (no semicolon) should be transformed into a direct assignment.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    am = AssertionMatch(
        is_exception_assertion=True,
        target_calls=None,
        leading_whitespace="",
        assertion_method="assertDoesNotThrow",
        original_text="String s = assertDoesNotThrow(() -> getString());",
        lambda_body="getString()",  # no semicolon -> should create direct assignment
        assigned_var_name="s",
        assigned_var_type="String",
        exception_class=None,
    )

    out = t._generate_replacement(am)  # 2.48μs -> 2.42μs (2.47% faster)
    # Expect: "String s = getString();"
    assert out == "String s = getString();", (
        "assertDoesNotThrow assigned with expression should become direct assignment"
    )


def test_strip_mode_exception_emits_try_catch_with_ignored_variable():
    # In strip mode, exception assertions should produce try/catch using 'ignored' as the catch name.
    t = JavaAssertTransformer(function_name="f", mode="strip")
    am = AssertionMatch(
        is_exception_assertion=True,
        target_calls=None,
        leading_whitespace="  ",
        assertion_method="assertThrows",
        original_text="assertThrows(Exception.class, () -> doIt());",
        lambda_body="doIt()",
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class="IOException",
    )

    out = t._generate_replacement(am)  # 2.21μs -> 2.18μs (0.962% faster)
    # Expect exception_class used if provided in strip mode
    expected = "  try { doIt(); } catch (IOException ignored) {}"
    assert out == expected, "Strip-mode exception assertions should use the provided exception class and 'ignored' name"


def test_no_target_calls_returns_empty_string():
    # If an assertion is not an exception and has no target calls (empty list), return empty string.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    am_empty = AssertionMatch(
        is_exception_assertion=False,
        target_calls=[],  # empty list should be treated as "no work to do"
        leading_whitespace="    ",
        assertion_method="assertEquals",
        original_text="assertEquals(1, 2);",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    out = t._generate_replacement(am_empty)  # 661ns -> 671ns (1.49% slower)
    assert out == "", "Non-exception assertion with no target calls should result in empty replacement"

    # Also verify None behaves the same (not a truthy target_calls)
    am_none = am_empty._replace(target_calls=None)
    out2 = t._generate_replacement(am_none)  # 341ns -> 331ns (3.02% faster)
    assert out2 == "", "target_calls set to None should be treated as no target calls and return empty string"


def test_empty_leading_whitespace_and_multiple_calls_in_strip_mode():
    # Ensure that when leading_whitespace is empty, base_indent is empty as well.
    t = JavaAssertTransformer(function_name="f", mode="strip")
    calls = [Call(full_call="a()"), Call(full_call="b()")]
    am = AssertionMatch(
        is_exception_assertion=False,
        target_calls=calls,
        leading_whitespace="",
        assertion_method="assertEquals",
        original_text="assertEquals(1, a());",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    out = t._generate_replacement(am)  # 2.62μs -> 2.52μs (3.61% faster)
    # Expect lines with no indentation or leading newlines
    expected = "a();\nb();"
    assert out == expected, "Empty leading whitespace should yield no indentation on any emitted lines"


def test_large_scale_capture_many_calls_and_correct_counter_and_naming():
    # Create a large number of calls (1000) and ensure the transformer generates the
    # correct number of capture statements with sequential numbering and updates the counter.
    t = JavaAssertTransformer(function_name="f", mode="capture")
    n = 1000
    calls = [Call(full_call=f"fun({i})") for i in range(1, n + 1)]

    am = AssertionMatch(
        is_exception_assertion=False,
        target_calls=calls,
        leading_whitespace="",  # no leading newline: first line won't start with newline
        assertion_method="assertNotNull",  # avoid literal inference code paths
        original_text="assertNotNull(x);",
        lambda_body=None,
        assigned_var_name=None,
        assigned_var_type=None,
        exception_class=None,
    )

    out = t._generate_replacement(am)  # 250μs -> 249μs (0.467% faster)

    # The output should contain n lines
    lines = out.split("\n")
    assert len(lines) == n, "Should emit exactly one line per target call"

    # First line should declare _cf_result1
    assert lines[0].startswith("Object _cf_result1 = fun(1);"), "First result variable name and content should match"

    # Last line should declare _cf_result1000
    assert lines[-1].startswith(f"Object _cf_result{n} = fun({n});"), "Last result variable name should increment to n"

    # invocation_counter should be updated to n
    assert t.invocation_counter == n, (
        "invocation_counter should equal number of generated variables after large generation"
    )
from codeflash.languages.java.remove_asserts import AssertionMatch, JavaAssertTransformer, TargetCall


def create_assertion_match(
    assertion_method: str = "assertEquals",
    original_text: str = "assertEquals(1, 1);",
    leading_whitespace: str = "",
    target_calls: list = None,
    is_exception_assertion: bool = False,
    lambda_body: str = None,
    assigned_var_name: str = None,
    assigned_var_type: str = None,
    exception_class: str = None,
    start_pos: int = 0,
    end_pos: int = 100,
    statement_type: str = "junit5",
):
    """Create a real AssertionMatch instance for testing."""
    if target_calls is None:
        target_calls = []

    real_target_calls = []
    for call in target_calls:
        if isinstance(call, str):
            real_target_calls.append(
                TargetCall(
                    receiver=None, method_name="test_func", arguments="", full_call=call, start_pos=0, end_pos=len(call)
                )
            )
        else:
            real_target_calls.append(call)

    return AssertionMatch(
        start_pos=start_pos,
        end_pos=end_pos,
        statement_type=statement_type,
        assertion_method=assertion_method,
        target_calls=real_target_calls,
        leading_whitespace=leading_whitespace,
        original_text=original_text,
        is_exception_assertion=is_exception_assertion,
        lambda_body=lambda_body,
        assigned_var_name=assigned_var_name,
        assigned_var_type=assigned_var_type,
        exception_class=exception_class,
    )


class TestGenerateReplacementBasic:
    """Basic tests for _generate_replacement functionality."""

    def test_empty_target_calls_returns_empty_string(self):
        """When there are no target calls, replacement should be empty."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(target_calls=[])
        result = transformer._generate_replacement(assertion)  # 501ns -> 511ns (1.96% slower)
        assert result == ""

    def test_single_target_call_capture_mode(self):
        """Single target call in capture mode should generate variable assignment."""
        transformer = JavaAssertTransformer("test_func", mode="capture")
        assertion = create_assertion_match(
            assertion_method="assertEquals",
            original_text="assertEquals(5, obj.getValue());",
            leading_whitespace="    ",
            target_calls=["obj.getValue()"],
        )
        result = transformer._generate_replacement(assertion)  # 9.02μs -> 9.49μs (4.96% slower)
        # Should contain the variable name and assignment
        assert "_cf_result" in result
        assert "obj.getValue();" in result
        assert "int _cf_result1 = obj.getValue();" in result

    def test_multiple_target_calls_capture_mode(self):
        """Multiple target calls should generate multiple variable assignments."""
        transformer = JavaAssertTransformer("test_func", mode="capture")
        assertion = create_assertion_match(
            assertion_method="assertEquals",
            original_text="assertEquals(obj.func1(), obj.func2());",
            leading_whitespace="  ",
            target_calls=["obj.func1()", "obj.func2()"],
        )
        result = transformer._generate_replacement(assertion)  # 12.6μs -> 12.9μs (2.40% slower)
        # Should contain both result variables
        assert "_cf_result1" in result
        assert "_cf_result2" in result
        assert "obj.func1();" in result
        assert "obj.func2();" in result

    def test_strip_mode_single_call(self):
        """In strip mode, should generate bare function call without variable capture."""
        transformer = JavaAssertTransformer("test_func", mode="strip")
        assertion = create_assertion_match(
            assertion_method="assertEquals",
            original_text="assertEquals(5, obj.getValue());",
            leading_whitespace="    ",
            target_calls=["obj.getValue()"],
        )
        result = transformer._generate_replacement(assertion)  # 1.90μs -> 1.88μs (1.06% faster)
        # Should not contain variable declarations or _cf_result
        assert "_cf_result" not in result
        assert "obj.getValue();" in result

    def test_void_return_type_generates_strip_replacement(self):
        """When target_return_type is 'void', should generate strip replacement."""
        transformer = JavaAssertTransformer("test_func", target_return_type="void", mode="capture")
        assertion = create_assertion_match(
            assertion_method="assertEquals",
            original_text="assertEquals(1, obj.method());",
            leading_whitespace="  ",
            target_calls=["obj.method()"],
        )
        result = transformer._generate_replacement(assertion)  # 1.83μs -> 1.83μs (0.000% faster)
        # Should not have variable capture even though mode is capture
        assert "_cf_result" not in result
        assert "obj.method();" in result

    def test_leading_whitespace_preserved_first_call(self):
        """First target call should preserve full leading whitespace including newlines."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertEquals", leading_whitespace="\n    ", target_calls=["func()"]
        )
        result = transformer._generate_replacement(assertion)  # 8.44μs -> 8.88μs (4.97% slower)
        # Result should start with the newline and spaces
        assert result.startswith("\n    ")

    def test_leading_whitespace_stripped_subsequent_calls(self):
        """Subsequent target calls should have newlines stripped from leading whitespace."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertEquals", leading_whitespace="\n    ", target_calls=["func1()", "func2()"]
        )
        result = transformer._generate_replacement(assertion)  # 8.72μs -> 9.32μs (6.45% slower)
        lines = result.split("\n")
        # First line should have newline prefix (from leading_whitespace)
        assert lines[0].startswith("    ")
        # Second line should not start with newlines, just spaces
        assert not lines[1].startswith("\n")

    def test_invocation_counter_incremented(self):
        """Invocation counter should be incremented for each call."""
        transformer = JavaAssertTransformer("test_func")
        assert transformer.invocation_counter == 0

        assertion = create_assertion_match(target_calls=["func()"])  # 7.88μs -> 8.35μs (5.64% slower)
        transformer._generate_replacement(assertion)
        assert transformer.invocation_counter == 1  # 4.56μs -> 1.71μs (166% faster)

        assertion2 = create_assertion_match(target_calls=["func()"])
        transformer._generate_replacement(assertion2)
        assert transformer.invocation_counter == 2

    def test_exception_assertion_detected(self):
        """Exception assertions should be delegated to exception handler."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertThrows",
            is_exception_assertion=True,
            lambda_body="calculator.divide(1, 0)",
            exception_class="ArithmeticException",
        )
        result = transformer._generate_replacement(assertion)  # 1.91μs -> 1.86μs (2.68% faster)
        # Should contain try-catch pattern
        assert "try" in result
        assert "catch" in result


class TestGenerateReplacementEdge:
    """Edge case tests for _generate_replacement."""

    def test_no_leading_whitespace(self):
        """Should work with no leading whitespace."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(leading_whitespace="", target_calls=["func()"])
        result = transformer._generate_replacement(assertion)  # 8.04μs -> 8.68μs (7.28% slower)
        assert "func();" in result

    def test_very_long_leading_whitespace(self):
        """Should handle excessive leading whitespace."""
        transformer = JavaAssertTransformer("test_func")
        long_whitespace = "\n" + " " * 100
        assertion = create_assertion_match(leading_whitespace=long_whitespace, target_calls=["func()"])
        result = transformer._generate_replacement(assertion)  # 7.82μs -> 8.65μs (9.51% slower)
        assert result.startswith(long_whitespace)

    def test_special_characters_in_function_calls(self):
        """Should handle function calls with special characters."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(target_calls=["obj.func_name_with_underscores()"])
        result = transformer._generate_replacement(assertion)  # 7.78μs -> 8.23μs (5.48% slower)
        assert "obj.func_name_with_underscores();" in result

    def test_nested_function_calls(self):
        """Should handle nested function calls."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(target_calls=["obj.outer(inner.method())"])
        result = transformer._generate_replacement(assertion)  # 7.54μs -> 8.23μs (8.39% slower)
        assert "obj.outer(inner.method());" in result

    def test_function_call_with_parameters(self):
        """Should handle function calls with multiple parameters."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(target_calls=["obj.method(1, 2, 3)"])
        result = transformer._generate_replacement(assertion)  # 7.41μs -> 8.25μs (10.1% slower)
        assert "obj.method(1, 2, 3);" in result

    def test_function_call_with_generic_types(self):
        """Should handle function calls returning generic types."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(target_calls=["list.get(0)"])
        result = transformer._generate_replacement(assertion)  # 7.67μs -> 8.23μs (6.80% slower)
        assert "list.get(0);" in result

    def test_multiple_calls_with_mixed_whitespace(self):
        """Multiple calls with mixed leading whitespace should be formatted correctly."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(leading_whitespace="\n\n  ", target_calls=["call1()", "call2()", "call3()"])
        result = transformer._generate_replacement(assertion)  # 9.10μs -> 9.56μs (4.82% slower)
        lines = result.split("\n")
        # First call has full leading whitespace
        assert lines[0] == ""
        assert lines[1] == ""
        # Subsequent calls use base indent only
        assert "call1()" in result
        assert "call2()" in result
        assert "call3()" in result

    def test_exception_assertion_with_variable_assignment(self):
        """Exception assertion with variable assignment should create proper try-catch."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertThrows",
            is_exception_assertion=True,
            lambda_body="code.execute()",
            assigned_var_name="ex",
            assigned_var_type="Exception",
            exception_class="Exception",
        )
        result = transformer._generate_replacement(assertion)  # 2.69μs -> 2.69μs (0.372% faster)
        # Should assign null first
        assert "ex = null;" in result
        # Should have try-catch
        assert "try" in result
        assert "_cf_caught" in result
        assert "ex = _cf_caught" in result

    def test_strip_mode_exception_assertion(self):
        """Exception assertion in strip mode should use simple try-catch."""
        transformer = JavaAssertTransformer("test_func", mode="strip")
        assertion = create_assertion_match(
            assertion_method="assertThrows",
            is_exception_assertion=True,
            lambda_body="code.execute()",
            exception_class="Exception",
        )
        result = transformer._generate_replacement(assertion)  # 1.78μs -> 1.76μs (1.08% faster)
        # Should have try-catch without _cf_caught
        assert "try" in result
        assert "catch" in result
        assert "_cf_caught" not in result
        assert "ignored" in result


class TestGenerateReplacementLargeScale:
    """Large-scale tests for performance and scalability."""

    def test_many_target_calls_capture_mode(self):
        """Should handle many target calls efficiently."""
        transformer = JavaAssertTransformer("test_func")
        # Create 100 target calls
        calls = [f"obj.method{i}()" for i in range(100)]
        assertion = create_assertion_match(target_calls=calls)
        result = transformer._generate_replacement(assertion)  # 29.5μs -> 30.1μs (2.29% slower)
        # Verify all calls are present
        for i, call in enumerate(calls):
            assert f"_cf_result{i + 1}" in result
            assert call + ";" in result
        # Counter should be incremented
        assert transformer.invocation_counter == 100

    def test_very_long_function_call_name(self):
        """Should handle very long function call names."""
        transformer = JavaAssertTransformer("test_func")
        long_call = "obj." + "method" * 50 + "()"
        assertion = create_assertion_match(target_calls=[long_call])
        result = transformer._generate_replacement(assertion)  # 8.06μs -> 8.60μs (6.18% slower)
        assert long_call + ";" in result

    def test_deeply_nested_function_calls(self):
        """Should handle deeply nested function calls."""
        transformer = JavaAssertTransformer("test_func")
        nested_call = "obj"
        for i in range(50):
            nested_call += f".method{i}()"
        assertion = create_assertion_match(target_calls=[nested_call])
        result = transformer._generate_replacement(assertion)  # 7.82μs -> 8.51μs (8.12% slower)
        assert nested_call + ";" in result

    def test_sequential_replacements_maintain_counter(self):
        """Multiple sequential replacements should maintain proper counter values."""
        transformer = JavaAssertTransformer("test_func")

        for i in range(100):
            assertion = create_assertion_match(target_calls=[f"func{i}()"])
            result = transformer._generate_replacement(assertion)  # 273μs -> 92.6μs (195% faster)
            assert f"_cf_result{i + 1}" in result

        assert transformer.invocation_counter == 100

    def test_mixed_call_types_large_scale(self):
        """Should handle a large mix of different call types."""
        transformer = JavaAssertTransformer("test_func")
        calls = []

        for i in range(100):
            # Mix different call patterns
            if i % 3 == 0:
                calls.append(f"simple{i}()")
            elif i % 3 == 1:
                calls.append(f"obj{i}.method{i}()")
            else:
                calls.append(f"obj{i}.method{i}(arg1, arg2, arg3)")

        assertion = create_assertion_match(target_calls=calls)
        result = transformer._generate_replacement(assertion)  # 31.1μs -> 31.5μs (1.24% slower)

        # Verify all calls are processed
        for call in calls:
            assert call + ";" in result

    def test_large_whitespace_variations(self):
        """Should handle many target calls with varying whitespace."""
        transformer = JavaAssertTransformer("test_func")
        # Create calls with progressively different indentation expectations
        calls = [f"func{i}()" for i in range(100)]
        assertion = create_assertion_match(leading_whitespace="\n" + " " * 16, target_calls=calls)
        result = transformer._generate_replacement(assertion)  # 29.1μs -> 29.8μs (2.39% slower)
        # Should maintain structure
        lines = result.split("\n")
        assert len(lines) == 100
        for i in range(100):
            assert f"_cf_result{i + 1}" in result

    def test_exception_assertion_many_calls_fallback(self):
        """Exception assertion should handle gracefully even with exception paths."""
        transformer = JavaAssertTransformer("test_func")
        for i in range(50):
            assertion = create_assertion_match(
                assertion_method="assertThrows",
                is_exception_assertion=True,
                lambda_body=f"code{i}.execute()",
                exception_class="Exception",
            )
            result = transformer._generate_replacement(assertion)  # 29.2μs -> 29.4μs (0.640% slower)
            # Should always produce valid output
            assert "try" in result or "_cf_ignored" in result

    def test_counter_consistency_across_exception_assertions(self):
        """Exception assertions should properly increment counter."""
        transformer = JavaAssertTransformer("test_func")
        initial_counter = transformer.invocation_counter

        for i in range(50):
            assertion = create_assertion_match(
                assertion_method="assertThrows",
                is_exception_assertion=True,
                lambda_body="code.execute()",
                exception_class="Exception",
            )
            transformer._generate_replacement(assertion)  # 28.9μs -> 29.0μs (0.521% slower)

        # Counter should have incremented by 50
        assert transformer.invocation_counter == initial_counter + 50

    def test_large_scale_strip_mode_performance(self):
        """Strip mode should handle many calls without slowdown."""
        transformer = JavaAssertTransformer("test_func", mode="strip")
        calls = [f"obj.method{i}({i})" for i in range(500)]

        assertion = create_assertion_match(target_calls=calls)
        result = transformer._generate_replacement(assertion)  # 47.7μs -> 44.9μs (6.23% faster)

        # All calls should be present as bare statements
        for call in calls:
            assert call + ";" in result
        # Should NOT have variable captures in strip mode
        assert "_cf_result" not in result


class TestGenerateReplacementTypeInference:
    """Tests for type inference in replacements."""

    def test_boolean_type_inference_from_assertTrue(self):
        """assertTrue/assertFalse should infer boolean return type."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertTrue", original_text="assertTrue(obj.isValid());", target_calls=["obj.isValid()"]
        )
        result = transformer._generate_replacement(assertion)  # 2.44μs -> 2.56μs (4.72% slower)
        # Should use boolean type
        assert "boolean _cf_result" in result

    def test_object_type_for_assertNull(self):
        """AssertNull should use Object type."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertNull", original_text="assertNull(obj.getValue());", target_calls=["obj.getValue()"]
        )
        result = transformer._generate_replacement(assertion)  # 2.50μs -> 2.46μs (1.22% faster)
        # Should use Object type
        assert "Object _cf_result" in result

    def test_default_object_type_for_fluent_assertions(self):
        """Fluent assertions (assertThat) should default to Object type."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertThat",
            original_text="assertThat(obj.getValue()).isNotNull();",
            target_calls=["obj.getValue()"],
        )
        result = transformer._generate_replacement(assertion)  # 2.65μs -> 2.52μs (4.75% faster)
        # Should use Object type as fallback
        assert "Object _cf_result" in result

    def test_type_consistency_across_multiple_calls(self):
        """All calls in single assertion should use same inferred type."""
        transformer = JavaAssertTransformer("test_func")
        assertion = create_assertion_match(
            assertion_method="assertTrue",
            original_text="assertTrue(obj.check1() && obj.check2());",
            target_calls=["obj.check1()", "obj.check2()"],
        )
        result = transformer._generate_replacement(assertion)  # 2.98μs -> 2.98μs (0.335% slower)
        # Both should be boolean
        assert "boolean _cf_result1" in result
        assert "boolean _cf_result2" in result

To edit these changes git checkout codeflash/optimize-pr1980-2026-04-03T13.47.39 and push.

Codeflash Static Badge

The optimization added a dictionary cache (`_type_infer_cache`) to memoize the results of `_infer_type_from_assertion_args`, which is an expensive method involving regex operations and string parsing. Before, every call to `_infer_return_type` for a value assertion (e.g., `assertEquals`) would re-parse the assertion's original text, spending ~92% of the method's runtime in `_infer_type_from_assertion_args`. With caching keyed on `(original_text, method)`, repeated assertions with identical text reuse the inferred type, cutting `_infer_return_type` time by 59% (1.38 ms → 570 µs) and overall runtime by 20% (886 µs → 708 µs). No functional regressions observed across all test cases.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Apr 3, 2026
@claude
Copy link
Copy Markdown
Contributor

claude bot commented Apr 3, 2026

Claude finished @codeflash-ai[bot]'s task in 2m 2s —— View job


PR Review Summary

Prek Checks

All checks pass (ruff, format). mypy reports no issues.

Code Review

PR size: SMALL (10 additions, 1 deletion in 1 file)

Optimization correctness: ✅ Valid

The change adds a _type_infer_cache: dict[tuple[str, str], str] to memoize results of _infer_type_from_assertion_args. Key observations:

  • The cache key (assertion.original_text, method) correctly captures all inputs that determine the output — the function is deterministic given these two values.
  • _infer_type_from_assertion_args always returns a str (never None), so the if cached is not None guard correctly distinguishes a cached miss from a cached empty string result (defensive and correct).
  • Cache is instance-level, appropriate since JavaAssertTransformer is created per-function-optimization, limiting cache size.
  • The speedup claim (25%) is credible: the method involves regex matching and string parsing, so caching repeated assertions with identical text is a genuine win.

No bugs, security issues, or breaking changes found.

Duplicate Detection

No duplicates detected. This is a localized memoization change with no equivalent logic elsewhere.

Test Coverage

281 generated regression tests pass with 100% coverage per the PR description. No existing unit tests for this file — no change to coverage posture.


Other open codeflash-ai[bot] PRs:


Last updated: 2026-04-03

@claude claude bot merged commit 10cfedf into cf-java-void-optimization Apr 3, 2026
28 of 30 checks passed
@claude claude bot deleted the codeflash/optimize-pr1980-2026-04-03T13.47.39 branch April 3, 2026 14:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants