diff --git a/.github/workflows/e2e-java-tracer.yaml b/.github/workflows/e2e-java-tracer.yaml index 7e92e9eee..6ed17ce90 100644 --- a/.github/workflows/e2e-java-tracer.yaml +++ b/.github/workflows/e2e-java-tracer.yaml @@ -3,17 +3,9 @@ name: E2E - Java Tracer on: pull_request: paths: - - 'codeflash/languages/java/**' - - 'codeflash/languages/base.py' - - 'codeflash/languages/registry.py' - - 'codeflash/tracer.py' - - 'codeflash/benchmarking/function_ranker.py' - - 'codeflash/discovery/functions_to_optimize.py' - - 'codeflash/optimization/**' - - 'codeflash/verification/**' + - 'codeflash/**' - 'codeflash-java-runtime/**' - - 'tests/test_languages/fixtures/java_tracer_e2e/**' - - 'tests/scripts/end_to_end_test_java_tracer.py' + - 'tests/**' - '.github/workflows/e2e-java-tracer.yaml' workflow_dispatch: diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java index f4b9ec453..c8b05a4f8 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java @@ -12,20 +12,181 @@ public class ReplayHelper { - private final Connection db; + private final Connection traceDb; + + // Codeflash instrumentation state — read from environment variables once + private final String mode; // "behavior", "performance", or null + private final int loopIndex; + private final String testIteration; + private final String outputFile; // SQLite path for behavior capture + private final int innerIterations; // for performance looping + + // Behavior mode: lazily opened SQLite connection for writing results + private Connection behaviorDb; + private boolean behaviorDbInitialized; public ReplayHelper(String traceDbPath) { try { - this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); + this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); } catch (SQLException e) { throw new RuntimeException("Failed to open trace database: " + traceDbPath, e); } + + // Read codeflash instrumentation env vars (set by the test runner) + this.mode = System.getenv("CODEFLASH_MODE"); + this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1); + this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0"); + this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE"); + this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10); } public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception { - // Query the function_calls table for this method at the given index + // Deserialize args and resolve method (done once, outside timing) + Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex); + Class targetClass = Class.forName(className); + + Type[] paramTypes = Type.getArgumentTypes(descriptor); + Class[] paramClasses = new Class[paramTypes.length]; + for (int i = 0; i < paramTypes.length; i++) { + paramClasses[i] = typeToClass(paramTypes[i]); + } + + Method method = targetClass.getDeclaredMethod(methodName, paramClasses); + method.setAccessible(true); + boolean isStatic = Modifier.isStatic(method.getModifiers()); + + Object instance = null; + if (!isStatic) { + try { + java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); + ctor.setAccessible(true); + instance = ctor.newInstance(); + } catch (NoSuchMethodException e) { + instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + } + } + + // Get the calling test method name from the stack trace + String testMethodName = getCallingTestMethodName(); + // Module name = the test class that called us + String testClassName = getCallingTestClassName(); + + if ("behavior".equals(mode)) { + replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else if ("performance".equals(mode)) { + replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else { + // No codeflash mode — just invoke (trace-only or manual testing) + method.invoke(instance, allArgs); + } + } + + private void replayBehavior(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + // testIteration goes at the END so the Comparator's lastUnderscore stripping + // removes it, making baseline (iteration=0) and candidate (iteration=N) keys match. + String invId = testMethodName + "_" + testIteration; + + // Print start marker (same format as behavior instrumentation) + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + Object result; + try { + result = method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + throw (Exception) e.getCause(); + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!"); + + // Write return value to SQLite for correctness comparison + if (outputFile != null && !outputFile.isEmpty()) { + writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result); + } + } + + private void replayPerformance(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + // Performance mode: run inner loop for JIT warmup, print timing for each iteration + int maxInner = innerIterations; + for (int inner = 0; inner < maxInner; inner++) { + int loopId = (loopIndex - 1) * maxInner + inner; + String invId = testMethodName; + + // Print start marker + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + try { + method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + // Swallow — performance mode doesn't check correctness + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!"); + } + } + + private void writeBehaviorResult(String testClassName, String testMethodName, + String functionName, String invId, + long durationNs, Object result) { + try { + ensureBehaviorDb(); + String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) { + ps.setString(1, testClassName); // test_module_path + ps.setString(2, testClassName); // test_class_name + ps.setString(3, testMethodName); // test_function_name + ps.setString(4, functionName); // function_getting_tested + ps.setInt(5, loopIndex); // loop_index + ps.setString(6, invId); // iteration_id + ps.setLong(7, durationNs); // runtime + ps.setBytes(8, serializeResult(result)); // return_value + ps.setString(9, "function_call"); // verification_type + ps.executeUpdate(); + } + } catch (Exception e) { + System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage()); + } + } + + private void ensureBehaviorDb() throws SQLException { + if (behaviorDbInitialized) return; + behaviorDbInitialized = true; + behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile); + try (java.sql.Statement stmt = behaviorDb.createStatement()) { + stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + } + + private byte[] serializeResult(Object result) { + if (result == null) return null; + try { + return Serializer.serialize(result); + } catch (Exception e) { + // Fall back to String.valueOf if Kryo fails + return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + } + + private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex) + throws SQLException { byte[] argsBlob; - try (PreparedStatement stmt = db.prepareStatement( + try (PreparedStatement stmt = traceDb.prepareStatement( "SELECT args FROM function_calls " + "WHERE classname = ? AND function = ? AND descriptor = ? " + "ORDER BY time_ns LIMIT 1 OFFSET ?")) { @@ -43,46 +204,35 @@ public void replay(String className, String methodName, String descriptor, int i } } - // Deserialize args Object deserialized = Serializer.deserialize(argsBlob); if (!(deserialized instanceof Object[])) { throw new RuntimeException("Deserialized args is not Object[], got: " + (deserialized == null ? "null" : deserialized.getClass().getName())); } - Object[] allArgs = (Object[]) deserialized; - - // Load the target class - Class targetClass = Class.forName(className); + return (Object[]) deserialized; + } - // Parse descriptor to find parameter types - Type[] paramTypes = Type.getArgumentTypes(descriptor); - Class[] paramClasses = new Class[paramTypes.length]; - for (int i = 0; i < paramTypes.length; i++) { - paramClasses[i] = typeToClass(paramTypes[i]); + private static String getCallingTestMethodName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + // Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method + for (int i = 3; i < stack.length; i++) { + String method = stack[i].getMethodName(); + if (method.startsWith("replay_")) { + return method; + } } + return stack.length > 3 ? stack[3].getMethodName() : "unknown"; + } - // Find the method - Method method = targetClass.getDeclaredMethod(methodName, paramClasses); - method.setAccessible(true); - - boolean isStatic = Modifier.isStatic(method.getModifiers()); - - if (isStatic) { - method.invoke(null, allArgs); - } else { - // Args contain only explicit parameters (no 'this'). - // Create a default instance via no-arg constructor or Kryo. - Object instance; - try { - java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); - ctor.setAccessible(true); - instance = ctor.newInstance(); - } catch (NoSuchMethodException e) { - // Fall back to Objenesis instantiation (no constructor needed) - instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + private static String getCallingTestClassName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + for (int i = 3; i < stack.length; i++) { + String cls = stack[i].getClassName(); + if (cls.contains("ReplayTest") || cls.contains("replay")) { + return cls; } - method.invoke(instance, allArgs); } + return stack.length > 3 ? stack[3].getClassName() : "unknown"; } private static Class typeToClass(Type type) throws ClassNotFoundException { @@ -106,11 +256,23 @@ private static Class typeToClass(Type type) throws ClassNotFoundException { } } + private static int parseIntEnv(String name, int defaultValue) { + String val = System.getenv(name); + if (val == null || val.isEmpty()) return defaultValue; + try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; } + } + + private static String getEnvOrDefault(String name, String defaultValue) { + String val = System.getenv(name); + return (val != null && !val.isEmpty()) ? val : defaultValue; + } + public void close() { - try { - if (db != null) db.close(); - } catch (SQLException e) { - System.err.println("Error closing ReplayHelper: " + e.getMessage()); + try { if (traceDb != null) traceDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper trace db: " + e.getMessage()); + } + try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage()); } } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java index 2a22b74f4..28c2d2998 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java @@ -22,6 +22,7 @@ public final class TraceRecorder { private final TracerConfig config; private final TraceWriter writer; private final ConcurrentHashMap functionCounts = new ConcurrentHashMap<>(); + private final AtomicInteger droppedCaptures = new AtomicInteger(0); private final int maxFunctionCount; private final ExecutorService serializerExecutor; @@ -82,11 +83,13 @@ private void onEntryImpl(String className, String methodName, String descriptor, argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); } catch (TimeoutException e) { future.cancel(true); + droppedCaptures.incrementAndGet(); System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." + methodName); return; } catch (Exception e) { Throwable cause = e.getCause() != null ? e.getCause() : e; + droppedCaptures.incrementAndGet(); System.err.println("[codeflash-tracer] Serialization failed for " + className + "." + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); return; @@ -113,11 +116,15 @@ public void flush() { } metadata.put("totalCaptures", String.valueOf(totalCaptures)); + int dropped = droppedCaptures.get(); + metadata.put("droppedCaptures", String.valueOf(dropped)); + writer.writeMetadata(metadata); writer.flush(); writer.close(); System.err.println("[codeflash-tracer] Captured " + totalCaptures - + " invocations across " + functionCounts.size() + " methods"); + + " invocations across " + functionCounts.size() + " methods" + + (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : "")); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java index c760ea636..90d4cd7a0 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java @@ -4,14 +4,20 @@ import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; +import java.util.Collections; +import java.util.Map; + public class TracingClassVisitor extends ClassVisitor { private final String internalClassName; + private final Map methodLineNumbers; private String sourceFile; - public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) { + public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName, + Map methodLineNumbers) { super(Opcodes.ASM9, classVisitor); this.internalClassName = internalClassName; + this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap(); } @Override @@ -37,7 +43,8 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, return mv; } + int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0); return new TracingMethodAdapter(mv, access, name, descriptor, - internalClassName, 0, sourceFile != null ? sourceFile : ""); + internalClassName, lineNumber, sourceFile != null ? sourceFile : ""); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java index 974c767a9..53ac775af 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java @@ -1,10 +1,16 @@ package com.codeflash.tracer; import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; import java.lang.instrument.ClassFileTransformer; import java.security.ProtectionDomain; +import java.util.HashMap; +import java.util.Map; public class TracingTransformer implements ClassFileTransformer { @@ -22,11 +28,6 @@ public byte[] transform(ClassLoader loader, String className, return null; } - // Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization) - if (TraceRecorder.isRecording()) { - return null; - } - // Skip internal JDK, framework, and synthetic classes if (className.startsWith("java/") || className.startsWith("javax/") @@ -51,6 +52,30 @@ public byte[] transform(ClassLoader loader, String className, private byte[] instrumentClass(String internalClassName, byte[] bytecode) { ClassReader cr = new ClassReader(bytecode); + + // Pre-scan: collect the first source line number for each method. + // ASM's visitMethod() doesn't provide line info — it arrives later via visitLineNumber(). + // We do a lightweight read pass first so the instrumentation pass has accurate line numbers. + Map methodLineNumbers = new HashMap<>(); + cr.accept(new ClassVisitor(Opcodes.ASM9) { + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, + String signature, String[] exceptions) { + String key = name + descriptor; + return new MethodVisitor(Opcodes.ASM9) { + private boolean captured = false; + + @Override + public void visitLineNumber(int line, Label start) { + if (!captured) { + methodLineNumbers.put(key, line); + captured = true; + } + } + }; + } + }, ClassReader.SKIP_FRAMES); + // Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames. // COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either // triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object". @@ -58,7 +83,7 @@ private byte[] instrumentClass(String internalClassName, byte[] bytecode) { // adjusts offsets for injected code. Our AdviceAdapter only injects at method entry // (before any branch points), so existing frames remain valid. ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS); - TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName); + TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName, methodLineNumbers); cr.accept(cv, ClassReader.EXPAND_FRAMES); return cw.toByteArray(); } diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 27876355b..41c4158cc 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -376,6 +376,7 @@ def _build_parser() -> ArgumentParser: subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension") subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow") + trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False) auth_parser = subparsers.add_parser("auth", help="Authentication commands") auth_subparsers = auth_parser.add_subparsers(dest="auth_command", help="Auth sub-commands") auth_subparsers.add_parser("login", help="Log in to Codeflash via OAuth") @@ -391,8 +392,6 @@ def _build_parser() -> ArgumentParser: compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)") compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml") - trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.") - trace_optimize.add_argument( "--max-function-count", type=int, diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 4d98fbde9..0280b2d85 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -554,11 +554,13 @@ def get_all_replay_test_functions( def _get_java_replay_test_functions( - replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str ) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: """Parse Java replay test files to extract functions and trace file path.""" from codeflash.languages.java.replay_test import parse_replay_test_metadata + project_root_path = Path(project_root_path) + trace_file_path: Path | None = None functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list) @@ -602,7 +604,7 @@ def _get_java_replay_test_functions( all_functions = lang_support.discover_functions(source_code, source_file) for func in all_functions: - if func.function_name in function_names: + if func.function_name in function_names or func.qualified_name in function_names: functions[source_file].append(func) if trace_file_path is None: diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 83b52f3c8..c1a8ae683 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -745,26 +745,35 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, if _is_test_annotation(stripped): if not helper_added: helper_added = True - result.append(line) - i += 1 - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) + # Check if the @Test line already contains the method signature and opening brace + # (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {) + if "{" in line: + # The annotation line IS the method signature — don't look for a separate one + result.append(line) i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break + method_lines = [line] + else: + result.append(line) i += 1 - # Add the method signature lines - for ml in method_lines: - result.append(ml) - i += 1 + # Collect any additional annotations + while i < len(lines) and lines[i].strip().startswith("@"): + result.append(lines[i]) + i += 1 + + # Now find the method signature and opening brace + method_lines = [] + while i < len(lines): + method_lines.append(lines[i]) + if "{" in lines[i]: + break + i += 1 + + # Add the method signature lines + for ml in method_lines: + result.append(ml) + i += 1 # Extract the test method name from the method signature test_method_name = _extract_test_method_name(method_lines) diff --git a/codeflash/languages/java/jfr_parser.py b/codeflash/languages/java/jfr_parser.py index 7775378e6..c7f55d507 100644 --- a/codeflash/languages/java/jfr_parser.py +++ b/codeflash/languages/java/jfr_parser.py @@ -2,6 +2,7 @@ import json import logging +import os import shutil import subprocess from datetime import datetime @@ -42,6 +43,13 @@ def _find_jfr_tool(self) -> str | None: candidate = Path(home) / "bin" / "jfr" if candidate.exists(): return str(candidate) + + java_home_env = os.environ.get("JAVA_HOME") + if java_home_env: + candidate = Path(java_home_env) / "bin" / "jfr" + if candidate.exists(): + return str(candidate) + return None def _parse(self) -> None: @@ -152,6 +160,8 @@ def _frame_to_key(self, frame: dict[str, Any]) -> str | None: method_name = method.get("name", "") if not class_name or not method_name: return None + # JFR uses / separators (JVM internal format), normalize to dots for package matching + class_name = class_name.replace("/", ".") return f"{class_name}.{method_name}" def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: @@ -159,7 +169,7 @@ def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: return method = frame.get("method", {}) self._method_info[key] = { - "class_name": method.get("type", {}).get("name", ""), + "class_name": method.get("type", {}).get("name", "").replace("/", "."), "method_name": method.get("name", ""), "descriptor": method.get("descriptor", ""), "line_number": str(frame.get("lineNumber", 0)), diff --git a/codeflash/languages/java/maven_strategy.py b/codeflash/languages/java/maven_strategy.py index fab3a144e..6ab91cb3d 100644 --- a/codeflash/languages/java/maven_strategy.py +++ b/codeflash/languages/java/maven_strategy.py @@ -906,7 +906,15 @@ def run_tests_via_build_tool( " --add-opens java.base/java.net=ALL-UNNAMED" " --add-opens java.base/java.util.zip=ALL-UNNAMED" ) - if javaagent_arg: + if enable_coverage: + # When coverage is enabled, JaCoCo's prepare-agent goal sets argLine via + # @{argLine}. Overriding -DargLine would clobber the JaCoCo agent flag. + # Pass add-opens and javaagent via JDK_JAVA_OPTIONS instead. + jdk_opts_parts = [add_opens_flags] + if javaagent_arg: + jdk_opts_parts.insert(0, javaagent_arg) + env["JDK_JAVA_OPTIONS"] = " ".join(jdk_opts_parts) + elif javaagent_arg: cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}") else: cmd.append(f"-DargLine={add_opens_flags}") diff --git a/codeflash/languages/java/replay_test.py b/codeflash/languages/java/replay_test.py index c753bf4fa..fcc452a80 100644 --- a/codeflash/languages/java/replay_test.py +++ b/codeflash/languages/java/replay_test.py @@ -12,9 +12,12 @@ logger = logging.getLogger(__name__) -def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int: - """Generate JUnit 5 replay test files from a trace SQLite database. +def generate_replay_tests( + trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5" +) -> int: + """Generate JUnit replay test files from a trace SQLite database. + Supports both JUnit 5 (default) and JUnit 4. Returns the number of test files generated. """ if not trace_db_path.exists(): @@ -44,22 +47,29 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P test_methods_code: list[str] = [] class_function_names: list[str] = [] + # Global test counter to avoid duplicate method names for overloaded Java methods + method_name_counters: dict[str, int] = {} for method_name, descriptor in method_list: - # Count invocations for this method count_result = conn.execute( "SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?", (classname, method_name, descriptor), ).fetchone() invocation_count = min(count_result[0], max_run_count) - class_function_names.append(method_name) + simple_class = classname.rsplit(".", 1)[-1] + class_function_names.append(f"{simple_class}.{method_name}") safe_method = _sanitize_identifier(method_name) for i in range(invocation_count): + # Use a global counter per method name to avoid collisions on overloaded methods + test_idx = method_name_counters.get(safe_method, 0) + method_name_counters[safe_method] = test_idx + 1 + escaped_descriptor = descriptor.replace('"', '\\"') + access = "public " if test_framework == "junit4" else "" test_methods_code.append( - f" @Test void replay_{safe_method}_{i}() throws Exception {{\n" + f" @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n" f' helper.replay("{classname}", "{method_name}", ' f'"{escaped_descriptor}", {i});\n' f" }}" @@ -69,18 +79,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P # Generate the test file functions_comment = ",".join(class_function_names) + if test_framework == "junit4": + test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n" + cleanup_annotation = "@AfterClass" + class_modifier = "public " + else: + test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n" + cleanup_annotation = "@AfterAll" + class_modifier = "" + test_content = ( f"// codeflash:functions={functions_comment}\n" f"// codeflash:trace_file={trace_db_path.as_posix()}\n" f"// codeflash:classname={classname}\n" f"package codeflash.replay;\n\n" - f"import org.junit.jupiter.api.Test;\n" - f"import org.junit.jupiter.api.AfterAll;\n" + f"{test_imports}" f"import com.codeflash.ReplayHelper;\n\n" - f"class {test_class_name} {{\n" + f"{class_modifier}class {test_class_name} {{\n" f" private static final ReplayHelper helper =\n" f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n' - f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n" + f" {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n" + + "\n\n".join(test_methods_code) + + "\n" "}\n" ) diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar index cfcee9390..546a8b89d 100644 Binary files a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py index 7b5a30421..ab8f19514 100644 --- a/codeflash/languages/java/tracer.py +++ b/codeflash/languages/java/tracer.py @@ -14,6 +14,39 @@ logger = logging.getLogger(__name__) +GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL + + +def _run_java_with_graceful_timeout( + java_command: list[str], env: dict[str, str], timeout: int, stage_name: str +) -> None: + """Run a Java command with graceful timeout handling. + + Sends SIGTERM first (allowing JFR dump and shutdown hooks to run), + then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds. + """ + if not timeout: + subprocess.run(java_command, env=env, check=False) + return + + import signal + + proc = subprocess.Popen(java_command, env=env) + try: + proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning( + "%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout + ) + proc.send_signal(signal.SIGTERM) + try: + proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT) + except subprocess.TimeoutExpired: + logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name) + proc.kill() + proc.wait() + + # --add-opens flags needed for Kryo serialization on Java 16+ ADD_OPENS_FLAGS = ( "--add-opens=java.base/java.util=ALL-UNNAMED " @@ -48,10 +81,7 @@ def trace( # Stage 1: JFR Profiling logger.info("Stage 1: Running JFR profiling...") jfr_env = self.build_jfr_env(jfr_file) - try: - subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("JFR profiling stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling") if not jfr_file.exists(): logger.warning("JFR file was not created at %s", jfr_file) @@ -62,10 +92,7 @@ def trace( trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout ) agent_env = self.build_agent_env(config_path) - try: - subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("Argument capture stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture") if not trace_db_path.exists(): logger.error("Trace database was not created at %s", trace_db_path) @@ -95,7 +122,12 @@ def create_tracer_config( def build_jfr_env(self, jfr_file: Path) -> dict[str, str]: env = os.environ.copy() - jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + # Use profile settings with increased sampling frequency (1ms instead of default 10ms) + # This captures more samples for short-running programs + jfr_opts = ( + f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + ",jdk.ExecutionSample#period=1ms" + ) existing = env.get("JAVA_TOOL_OPTIONS", "") env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip() return env @@ -133,7 +165,7 @@ def detect_packages_from_source(module_root: Path) -> list[str]: if stripped.startswith("package "): pkg = stripped[8:].rstrip(";").strip() parts = pkg.split(".") - prefix = ".".join(parts[: min(2, len(parts))]) + prefix = ".".join(parts[: min(3, len(parts))]) packages.add(prefix) break if stripped and not stripped.startswith("//"): @@ -153,6 +185,7 @@ def run_java_tracer( max_function_count: int = 256, timeout: int = 0, max_run_count: int = 256, + test_framework: str = "junit5", ) -> tuple[Path, Path, int]: """High-level entry point: trace a Java command and generate replay tests. @@ -169,7 +202,11 @@ def run_java_tracer( ) test_count = generate_replay_tests( - trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count + trace_db_path=trace_db, + output_dir=output_dir, + project_root=project_root, + max_run_count=max_run_count, + test_framework=test_framework, ) return trace_db, jfr_file, test_count diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 0296ab24e..09fd64103 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -950,7 +950,7 @@ def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: by_id: dict[InvocationId, list[int]] = {} for result in self.test_results: if result.did_pass: - if result.runtime: + if result.runtime is not None: by_id.setdefault(result.id, []).append(result.runtime) else: msg = ( diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 26dbb25c2..48920be8c 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -349,8 +349,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: max_function_count = getattr(config, "max_function_count", 256) timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0) + console.print("[bold]Java project detected[/]") + console.print(f" Project root: {project_root}") + console.print(f" Module root: {getattr(config, 'module_root', '?')}") + console.print(f" Tests root: {getattr(config, 'tests_root', '?')}") + from codeflash.code_utils.code_utils import get_run_tmp_file - from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.tracer import JavaTracer, run_java_tracer tracer = JavaTracer() @@ -360,12 +364,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: trace_db_path = get_run_tmp_file(Path("java_trace.db")) - # Place replay tests in the project's test source tree so Maven/Gradle can compile them - test_root = find_test_root(project_root) - if test_root: - output_dir = test_root / "codeflash" / "replay" + # Place replay tests in the project's test source tree so Maven/Gradle can compile them. + # Use the config's tests_root (correctly resolved for multi-module projects) not find_test_root(). + tests_root = Path(getattr(config, "tests_root", "")) + if tests_root.is_dir(): + output_dir = tests_root / "codeflash" / "replay" else: - output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay" + from codeflash.languages.java.build_tools import find_test_root + + test_root = find_test_root(project_root) + output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay" output_dir.mkdir(parents=True, exist_ok=True) # Remaining args after our flags are the Java command @@ -377,6 +385,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: sys.exit(1) java_command = remaining + # Detect test framework for replay test generation + from codeflash.languages.java.config import detect_java_project + + java_config = detect_java_project(project_root) + test_framework = java_config.test_framework if java_config else "junit5" + trace_db, jfr_file, test_count = run_java_tracer( java_command=java_command, trace_db_path=trace_db_path, @@ -385,6 +399,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: output_dir=output_dir, max_function_count=max_function_count, timeout=timeout, + test_framework=test_framework, ) console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated") @@ -404,7 +419,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: config.replay_test = replay_test_paths config.previous_checkpoint_functions = None config.effort = EffortLevel.HIGH.value - config.no_pr = True + config.no_pr = getattr(config, "no_pr", False) config.file = None config.function = None config.test_project_root = project_root diff --git a/tests/scripts/end_to_end_test_java_tracer.py b/tests/scripts/end_to_end_test_java_tracer.py index e904a4e98..0f9f8a2ff 100644 --- a/tests/scripts/end_to_end_test_java_tracer.py +++ b/tests/scripts/end_to_end_test_java_tracer.py @@ -50,6 +50,7 @@ def run_test(expected_improvement_pct: int) -> bool: "--no-project", "-m", "codeflash.main", + "--no-pr", "optimize", "java", "-cp", @@ -59,6 +60,7 @@ def run_test(expected_improvement_pct: int) -> bool: env = os.environ.copy() env["PYTHONIOENCODING"] = "utf-8" + env["PYTHONUNBUFFERED"] = "1" logging.info(f"Running command: {' '.join(command)}") logging.info(f"Working directory: {fixture_dir}") process = subprocess.Popen( @@ -73,13 +75,11 @@ def run_test(expected_improvement_pct: int) -> bool: output = [] for line in process.stdout: - logging.info(line.strip()) + print(line, end="", flush=True) output.append(line) return_code = process.wait() stdout = "".join(output) - if return_code != 0: - logging.error(f"Full output:\n{stdout}") if return_code != 0: logging.error(f"Command returned exit code {return_code}") @@ -90,7 +90,7 @@ def run_test(expected_improvement_pct: int) -> bool: logging.error("Failed to find replay test generation message") return False - # Validate: replay tests were discovered + # Validate: replay tests were discovered (global count) replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout) if not replay_match: logging.error("Failed to find replay test discovery message") @@ -101,6 +101,17 @@ def run_test(expected_improvement_pct: int) -> bool: return False logging.info(f"Replay tests discovered: {num_replay}") + # Validate: replay test files were used per-function + replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout) + if not replay_file_match: + logging.error("Failed to find per-function replay test file discovery message") + return False + num_replay_files = int(replay_file_match.group(1)) + if num_replay_files == 0: + logging.error("No replay test files discovered per-function") + return False + logging.info(f"Replay test files per-function: {num_replay_files}") + # Validate: at least one optimization was found if "⚡️ Optimization successful! 📄 " not in stdout: logging.error("Failed to find optimization success message") diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java index 9b6078000..7beb2a4ea 100644 --- a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java @@ -36,20 +36,30 @@ public int instanceMethod(int x, int y) { } public static void main(String[] args) { - // Exercise the methods so the tracer can capture invocations - System.out.println("computeSum(100) = " + computeSum(100)); - System.out.println("computeSum(50) = " + computeSum(50)); + // Run methods with large inputs so JFR can capture CPU samples. + // Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval. + for (int round = 0; round < 1000; round++) { + computeSum(100_000); + repeatString("hello world ", 1000); + + List nums = new ArrayList<>(); + for (int i = 1; i <= 10_000; i++) nums.add(i); + filterEvens(nums); + Workload w = new Workload(); + w.instanceMethod(100_000, 42); + } + + // Also call with small inputs for variety in traced args + System.out.println("computeSum(100) = " + computeSum(100)); System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3)); - System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5)); - List nums = new ArrayList<>(); - for (int i = 1; i <= 10; i++) nums.add(i); - System.out.println("filterEvens(1..10) = " + filterEvens(nums)); + List small = new ArrayList<>(); + for (int i = 1; i <= 10; i++) small.add(i); + System.out.println("filterEvens(1..10) = " + filterEvens(small)); Workload w = new Workload(); System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3)); - System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2)); System.out.println("Workload complete."); } diff --git a/tests/test_languages/test_java/test_jfr_parser.py b/tests/test_languages/test_java/test_jfr_parser.py new file mode 100644 index 000000000..8b5cf8a6e --- /dev/null +++ b/tests/test_languages/test_java/test_jfr_parser.py @@ -0,0 +1,302 @@ +"""Tests for JFR parser — class name normalization, package filtering, addressable time.""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import patch + +import pytest + +from codeflash.languages.java.jfr_parser import JfrProfile + + +def _make_jfr_json(events: list[dict]) -> str: + """Create fake JFR JSON output matching the jfr print format.""" + return json.dumps({"recording": {"events": events}}) + + +def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict: + return { + "type": "jdk.ExecutionSample", + "values": { + "startTime": start_time, + "stackTrace": { + "frames": [ + { + "method": { + "type": {"name": class_name}, + "name": method_name, + "descriptor": "()V", + }, + "lineNumber": 42, + } + ], + }, + }, + } + + +class TestClassNameNormalization: + """Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo).""" + + def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + assert profile._total_samples == 3 + assert len(profile._method_samples) == 2 + + # Keys should use dots, not slashes + assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples + assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples + + def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/MyClass", "myMethod")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + info = profile._method_info.get("com.example.MyClass.myMethod") + assert info is not None + assert info["class_name"] == "com.example.MyClass" + assert info["method_name"] == "myMethod" + + +class TestPackageFiltering: + def test_filters_by_package_prefix(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/Value", "get"), + _make_execution_sample("java/util/HashMap", "put"), + _make_execution_sample("com/aerospike/benchmarks/Main", "main"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + # Only com.aerospike classes should be in samples + assert len(profile._method_samples) == 2 + assert "com.aerospike.client.Value.get" in profile._method_samples + assert "com.aerospike.benchmarks.Main.main" in profile._method_samples + assert "java.util.HashMap.put" not in profile._method_samples + + def test_empty_packages_includes_all(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "bar"), + _make_execution_sample("java/lang/String", "length"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, []) + + assert len(profile._method_samples) == 2 + + +class TestAddressableTime: + def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + # 3 samples for methodA, 1 for methodB, spanning 10 seconds + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"), + _make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA") + time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB") + + # methodA has 3x the samples of methodB, so 3x the addressable time + assert time_a > 0 + assert time_b > 0 + assert time_a == pytest.approx(time_b * 3, rel=0.01) + + def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/Foo", "bar")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0 + + +class TestMethodRanking: + def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/C", "cold"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 3 + assert ranking[0]["method_name"] == "hot" + assert ranking[0]["sample_count"] == 3 + assert ranking[1]["method_name"] == "warm" + assert ranking[1]["sample_count"] == 2 + assert ranking[2]["method_name"] == "cold" + assert ranking[2]["sample_count"] == 1 + + def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json([]) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_method_ranking() == [] + + def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/nested/Deep", "method")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 1 + assert ranking[0]["class_name"] == "com.example.nested.Deep" + + +class TestGracefulTimeout: + """Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL.""" + + def test_sends_sigterm_on_timeout(self) -> None: + import signal + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + # Run a sleep command with a 1s timeout — should get SIGTERM'd + import os + + env = os.environ.copy() + _run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test") + # If we get here, the process was killed (didn't hang for 60s) + + def test_no_timeout_runs_normally(self) -> None: + import os + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + env = os.environ.copy() + _run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test") + # Should complete without error + + +class TestProjectRootResolution: + """Test that project_root is correctly set for Java multi-module projects.""" + + def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """For multi-module Maven, project_root should be the root with , not a sub-module.""" + # Create a multi-module project + (tmp_path / "pom.xml").write_text( + 'client', + encoding="utf-8", + ) + client = tmp_path / "client" + client.mkdir() + (client / "pom.xml").write_text("", encoding="utf-8") + src = client / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import parse_config_file + + config, config_path = parse_config_file() + assert config["language"] == "java" + + # config_path should be the project root directory + assert config_path == tmp_path + + def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """project_root from process_pyproject_config should be a Path for Java projects.""" + from argparse import Namespace + + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.cli_cmds.cli import process_pyproject_config + + # Create a minimal args namespace matching what parse_args produces + args = Namespace( + config_file=None, module_root=None, tests_root=None, benchmarks_root=None, + ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None, + disable_imports_sorting=None, git_remote=None, override_fixtures=None, + benchmark=False, verbose=False, version=False, show_config=False, reset_config=False, + ) + args = process_pyproject_config(args) + + assert hasattr(args, "project_root") + assert isinstance(args.project_root, Path) + assert args.project_root == tmp_path diff --git a/tests/test_languages/test_java/test_replay_test_generation.py b/tests/test_languages/test_java/test_replay_test_generation.py new file mode 100644 index 000000000..da7138114 --- /dev/null +++ b/tests/test_languages/test_java/test_replay_test_generation.py @@ -0,0 +1,255 @@ +"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import pytest + +from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata + + +@pytest.fixture +def trace_db(tmp_path: Path) -> Path: + """Create a trace database with sample function calls.""" + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"), + ) + conn.commit() + conn.close() + return db_path + + +@pytest.fixture +def trace_db_overloaded(tmp_path: Path) -> Path: + """Create a trace database with overloaded methods (same name, different descriptors).""" + db_path = tmp_path / "trace_overloaded.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + # Two overloads of estimateKeySize with different descriptors + for i in range(3): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"), + ) + for i in range(2): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + "call", + "estimateKeySize", + "com.example.Command", + "Command.java", + 15, + "(Ljava/lang/String;)I", + (i + 10) * 1000, + b"\x00", + ), + ) + conn.commit() + conn.close() + return db_path + + +class TestGenerateReplayTestsJunit5: + def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.jupiter.api.Test;" in content + assert "import org.junit.jupiter.api.AfterAll;" in content + assert "@Test void replay_add_0()" in content + + def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "class ReplayTest_" in content + assert "public class ReplayTest_" not in content + + +class TestGenerateReplayTestsJunit4: + def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.Test;" in content + assert "import org.junit.AfterClass;" in content + assert "org.junit.jupiter" not in content + + def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@Test public void replay_add_0()" in content + + def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "public class ReplayTest_" in content + + def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@AfterClass" in content + assert "@AfterAll" not in content + + +class TestOverloadedMethods: + def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + + # Should have 5 unique methods (3 from first overload + 2 from second) + assert "replay_estimateKeySize_0" in content + assert "replay_estimateKeySize_1" in content + assert "replay_estimateKeySize_2" in content + assert "replay_estimateKeySize_3" in content + assert "replay_estimateKeySize_4" in content + + # Verify no duplicates by counting occurrences + lines = content.splitlines() + method_lines = [l for l in lines if "void replay_estimateKeySize_" in l] + method_names = [l.split("void ")[1].split("(")[0] for l in method_lines] + assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}" + + +class TestReplayTestInstrumentation: + def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None: + """Replay tests with compact @Test lines should be instrumented without orphaned code.""" + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert instrumented is not None + assert "__perfinstrumented" in instrumented + + # Verify no code outside class body + lines = instrumented.splitlines() + class_closed = False + for line in lines: + if line.strip() == "}" and not line.startswith(" "): + class_closed = True + elif class_closed and line.strip() and not line.strip().startswith("//"): + pytest.fail(f"Orphaned code outside class: {line}") + + def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="performance", + ) + assert success + assert "__perfonlyinstrumented" in instrumented + + def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + test_file = tmp_path / "CalculatorTest.java" + test_file.write_text( + """ +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { + Calculator calc = new Calculator(); + calc.add(1, 2); + } +} +""", + encoding="utf-8", + ) + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert "CODEFLASH_LOOP_INDEX" in instrumented