diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py
index 27876355b..d54aa1b60 100644
--- a/codeflash/cli_cmds/cli.py
+++ b/codeflash/cli_cmds/cli.py
@@ -1,19 +1,29 @@
+from __future__ import annotations
+
import logging
import os
import sys
-from argparse import SUPPRESS, ArgumentParser, Namespace
+from argparse import SUPPRESS, ArgumentParser
from functools import lru_cache
from pathlib import Path
+from typing import TYPE_CHECKING, Any
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.console import apologize_and_exit, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
from codeflash.code_utils.config_parser import parse_config_file
+from codeflash.languages import set_current_language
+from codeflash.languages.language_enum import Language
from codeflash.languages.test_framework import set_current_test_framework
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version
+if TYPE_CHECKING:
+ from argparse import Namespace
+
+ from codeflash.code_utils.config_parser import LanguageConfig
+
def parse_args() -> Namespace:
parser = _build_parser()
@@ -89,6 +99,23 @@ def process_pyproject_config(args: Namespace) -> Namespace:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
+
+ language = None
+ lang_str = pyproject_config.get("language")
+ if lang_str:
+ language = Language(lang_str)
+
+ args = resolve_config_onto_args(args, pyproject_config, pyproject_file_path, language)
+
+ if is_LSP_enabled():
+ args.all = None
+ return args
+ return handle_optimize_all_arg_parsing(args)
+
+
+def resolve_config_onto_args(
+ args: Namespace, config: dict[str, Any], config_path: Path, language: Language | None
+) -> Namespace:
supported_keys = [
"module_root",
"tests_root",
@@ -102,26 +129,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"override_fixtures",
]
for key in supported_keys:
- if key in pyproject_config and (
- (hasattr(args, key.replace("-", "_")) and getattr(args, key.replace("-", "_")) is None)
- or not hasattr(args, key.replace("-", "_"))
+ attr_key = key.replace("-", "_")
+ if key in config and (
+ (hasattr(args, attr_key) and getattr(args, attr_key) is None) or not hasattr(args, attr_key)
):
- setattr(args, key.replace("-", "_"), pyproject_config[key])
+ setattr(args, attr_key, config[key])
+
assert args.module_root is not None, "--module-root must be specified"
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
- # For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
- # Default to module_root if not specified
- is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
- is_java_project = pyproject_config.get("language") == "java"
+ is_js_ts = language in (Language.JAVASCRIPT, Language.TYPESCRIPT)
+ is_java = language == Language.JAVA
- # Set the test framework singleton for JS/TS projects
- if is_js_ts_project and pyproject_config.get("test_framework"):
- set_current_test_framework(pyproject_config["test_framework"])
+ if language is not None:
+ set_current_language(language)
+
+ if is_js_ts and config.get("test_framework"):
+ set_current_test_framework(config["test_framework"])
if args.tests_root is None:
- if is_java_project:
- # Try standard Maven/Gradle test directories
+ if is_java:
for test_dir in ["src/test/java", "test", "tests"]:
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
if not test_path.is_absolute():
@@ -131,13 +158,11 @@ def process_pyproject_config(args: Namespace) -> Namespace:
break
if args.tests_root is None:
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
- elif is_js_ts_project:
- # Try common JS test directories at project root first
+ elif is_js_ts:
for test_dir in ["test", "tests", "__tests__"]:
if Path(test_dir).is_dir():
args.tests_root = test_dir
break
- # If not found at project root, try inside module_root (e.g., src/test, src/__tests__)
if args.tests_root is None and args.module_root:
module_root_path = Path(args.module_root)
for test_dir in ["test", "tests", "__tests__"]:
@@ -145,15 +170,14 @@ def process_pyproject_config(args: Namespace) -> Namespace:
if test_path.is_dir():
args.tests_root = str(test_path)
break
- # Final fallback: default to module_root
- # Note: This may cause issues if tests are colocated with source files
- # In such cases, the user should explicitly configure testsRoot in package.json
if args.tests_root is None:
args.tests_root = args.module_root
else:
raise AssertionError("--tests-root must be specified")
+
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
- if args.benchmark:
+
+ if getattr(args, "benchmark", False):
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
assert Path(args.benchmarks_root).is_dir(), (
f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
@@ -179,29 +203,25 @@ def process_pyproject_config(args: Namespace) -> Namespace:
require_github_app_or_exit(owner, repo_name)
- # Project root path is one level above the specified directory, because that's where the module can be imported from
args.module_root = Path(args.module_root).resolve()
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
- # Normalize ignore paths, supporting both literal paths and glob patterns
- # Use module_root as base path for resolving relative paths and patterns
args.ignore_paths = normalize_ignore_paths(args.ignore_paths, base_path=args.module_root)
- # If module-root is "." then all imports are relatives to it.
- # in this case, the ".." becomes outside project scope, causing issues with un-importable paths
- args.project_root = project_root_from_module_root(Path(args.module_root), pyproject_file_path)
+ args.project_root = project_root_from_module_root(Path(args.module_root), config_path)
args.tests_root = Path(args.tests_root).resolve()
if args.benchmarks_root:
args.benchmarks_root = Path(args.benchmarks_root).resolve()
- args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
+ args.test_project_root = project_root_from_module_root(args.tests_root, config_path)
- if is_java_project and pyproject_file_path.is_dir():
- # For Java projects, pyproject_file_path IS the project root directory (not a file).
- # Override project_root which may have resolved to a sub-module.
- args.project_root = pyproject_file_path.resolve()
- args.test_project_root = pyproject_file_path.resolve()
- if is_LSP_enabled():
- args.all = None
- return args
- return handle_optimize_all_arg_parsing(args)
+ if is_java and config_path.is_dir():
+ resolved_config = config_path.resolve()
+ try:
+ args.module_root.relative_to(resolved_config)
+ args.project_root = resolved_config
+ args.test_project_root = resolved_config
+ except ValueError:
+ pass
+
+ return args
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
@@ -221,6 +241,10 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
return module_root.parent.resolve()
+def apply_language_config(args: Namespace, lang_config: LanguageConfig) -> Namespace:
+ return resolve_config_onto_args(args, lang_config.config, lang_config.config_path, lang_config.language)
+
+
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
no_pr = getattr(args, "no_pr", False)
@@ -234,14 +258,14 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
- except git.exc.InvalidGitRepositoryError:
+ except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
mode = "--all" if hasattr(args, "all") else "--file"
logger.exception(
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
- git_remote = getattr(args, "git_remote", None)
+ git_remote = getattr(args, "git_remote", None) or "origin"
if not check_and_push_branch(git_repo, git_remote=git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py
index 196779589..f32af13cb 100644
--- a/codeflash/code_utils/config_parser.py
+++ b/codeflash/code_utils/config_parser.py
@@ -1,17 +1,29 @@
from __future__ import annotations
+import logging
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
import tomlkit
from codeflash.code_utils.config_js import find_package_json, parse_package_json_config
+from codeflash.languages.language_enum import Language
from codeflash.lsp.helpers import is_LSP_enabled
+logger = logging.getLogger("codeflash")
+
PYPROJECT_TOML_CACHE: dict[Path, Path] = {}
ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {}
+@dataclass
+class LanguageConfig:
+ config: dict[str, Any]
+ config_path: Path
+ language: Language
+
+
def _try_parse_java_build_config() -> tuple[dict[str, Any], Path] | None:
"""Detect Java project from build files and parse config from pom.xml/gradle.properties.
@@ -103,6 +115,170 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]:
return list(list_of_conftest_files)
+def normalize_toml_config(config: dict[str, Any], config_file_path: Path) -> dict[str, Any]:
+ path_keys = ["module-root", "tests-root", "benchmarks-root"]
+ path_list_keys = ["ignore-paths"]
+ str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
+ bool_keys = {
+ "override-fixtures": False,
+ "disable-telemetry": False,
+ "disable-imports-sorting": False,
+ "benchmark": False,
+ }
+ list_str_keys = {"formatter-cmds": []}
+
+ for key, default_value in str_keys.items():
+ if key in config:
+ config[key] = str(config[key])
+ else:
+ config[key] = default_value
+ for key, default_value in bool_keys.items():
+ if key in config:
+ config[key] = bool(config[key])
+ else:
+ config[key] = default_value
+ for key in path_keys:
+ if key in config:
+ config[key] = str((config_file_path.parent / Path(config[key])).resolve())
+ for key, default_value in list_str_keys.items():
+ if key in config:
+ config[key] = [str(cmd) for cmd in config[key]]
+ else:
+ config[key] = default_value
+ for key in path_list_keys:
+ if key in config:
+ config[key] = [str((config_file_path.parent / path).resolve()) for path in config[key]]
+ else:
+ config[key] = []
+
+ # Convert hyphenated keys to underscored keys
+ for key in list(config.keys()):
+ if "-" in key:
+ config[key.replace("-", "_")] = config[key]
+ del config[key]
+
+ return config
+
+
+def _parse_java_config_for_dir(dir_path: Path) -> dict[str, Any] | None:
+ from codeflash.languages.java.build_config_strategy import parse_java_project_config
+
+ return parse_java_project_config(dir_path)
+
+
+_SUBDIR_SKIP = frozenset(
+ {
+ ".git",
+ ".hg",
+ ".svn",
+ "node_modules",
+ ".venv",
+ "venv",
+ "__pycache__",
+ "target",
+ "build",
+ "dist",
+ ".tox",
+ ".mypy_cache",
+ ".ruff_cache",
+ ".pytest_cache",
+ }
+)
+
+
+def _check_dir_for_configs(dir_path: Path, configs: list[LanguageConfig], seen_languages: set[Language]) -> None:
+ if Language.PYTHON not in seen_languages:
+ pyproject = dir_path / "pyproject.toml"
+ if pyproject.exists():
+ try:
+ with pyproject.open("rb") as f:
+ data = tomlkit.parse(f.read())
+ tool = data.get("tool", {})
+ if isinstance(tool, dict) and "codeflash" in tool:
+ raw_config = dict(tool["codeflash"])
+ normalized = normalize_toml_config(raw_config, pyproject)
+ seen_languages.add(Language.PYTHON)
+ configs.append(LanguageConfig(config=normalized, config_path=pyproject, language=Language.PYTHON))
+ except Exception:
+ logger.debug("Failed to parse Python config in %s", dir_path, exc_info=True)
+
+ if Language.JAVASCRIPT not in seen_languages and Language.TYPESCRIPT not in seen_languages:
+ package_json = dir_path / "package.json"
+ if package_json.exists():
+ try:
+ import json
+
+ pkg_data = json.loads(package_json.read_text(encoding="utf-8"))
+ if isinstance(pkg_data, dict) and "codeflash" in pkg_data:
+ result = parse_package_json_config(package_json)
+ if result is not None:
+ config, path = result
+ lang = Language(config.get("language", "javascript"))
+ seen_languages.add(lang)
+ configs.append(LanguageConfig(config=config, config_path=path, language=lang))
+ except Exception:
+ logger.debug("Failed to parse JS/TS config in %s", dir_path, exc_info=True)
+
+ if Language.JAVA not in seen_languages:
+ if (
+ (dir_path / "pom.xml").exists()
+ or (dir_path / "build.gradle").exists()
+ or (dir_path / "build.gradle.kts").exists()
+ ):
+ try:
+ java_config = _parse_java_config_for_dir(dir_path)
+ if java_config is not None:
+ seen_languages.add(Language.JAVA)
+ configs.append(LanguageConfig(config=java_config, config_path=dir_path, language=Language.JAVA))
+ except Exception:
+ logger.debug("Failed to parse Java config in %s", dir_path, exc_info=True)
+
+
+def find_all_config_files(start_dir: Path | None = None) -> list[LanguageConfig]:
+ if start_dir is None:
+ start_dir = Path.cwd()
+
+ configs: list[LanguageConfig] = []
+ seen_languages: set[Language] = set()
+
+ # Determine the git root as the upward walk boundary.
+ # Without this, a pyproject.toml in ~ would be picked up from any subdirectory.
+ git_root: Path | None = None
+ try:
+ import subprocess
+
+ result = subprocess.run(
+ ["git", "rev-parse", "--show-toplevel"], capture_output=True, text=True, cwd=start_dir, check=False
+ )
+ if result.returncode == 0:
+ git_root = Path(result.stdout.strip()).resolve()
+ except Exception:
+ pass
+
+ # Walk upward from start_dir to git root (closest config wins per language)
+ dir_path = start_dir.resolve()
+ while True:
+ _check_dir_for_configs(dir_path, configs, seen_languages)
+
+ if git_root is not None and dir_path == git_root:
+ break
+ parent = dir_path.parent
+ if parent == dir_path:
+ break
+ dir_path = parent
+
+ # Scan immediate subdirectories for monorepo language subprojects
+ resolved_start = start_dir.resolve()
+ try:
+ subdirs = sorted(p for p in resolved_start.iterdir() if p.is_dir() and p.name not in _SUBDIR_SKIP)
+ except OSError:
+ subdirs = []
+ for subdir in subdirs:
+ _check_dir_for_configs(subdir, configs, seen_languages)
+
+ return configs
+
+
def parse_config_file(
config_file_path: Path | None = None, override_formatter_check: bool = False
) -> tuple[dict[str, Any], Path]:
@@ -174,55 +350,13 @@ def parse_config_file(
if config == {} and lsp_mode:
return {}, config_file_path
- # Preserve language field if present (important for JS/TS projects)
- # default values:
- path_keys = ["module-root", "tests-root", "benchmarks-root"]
- path_list_keys = ["ignore-paths"]
- str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
- bool_keys = {
- "override-fixtures": False,
- "disable-telemetry": False,
- "disable-imports-sorting": False,
- "benchmark": False,
- }
- # Note: formatter-cmds defaults to empty list. For Python projects, black is typically
- # detected by the project detector. For Java projects, no formatter is supported yet.
- list_str_keys = {"formatter-cmds": []}
-
- for key, default_value in str_keys.items():
- if key in config:
- config[key] = str(config[key])
- else:
- config[key] = default_value
- for key, default_value in bool_keys.items():
- if key in config:
- config[key] = bool(config[key])
- else:
- config[key] = default_value
- for key in path_keys:
- if key in config:
- config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve())
- for key, default_value in list_str_keys.items():
- if key in config:
- config[key] = [str(cmd) for cmd in config[key]]
- else:
- config[key] = default_value
-
- for key in path_list_keys:
- if key in config:
- config[key] = [str((Path(config_file_path).parent / path).resolve()) for path in config[key]]
- else:
- config[key] = []
+ config = normalize_toml_config(config, config_file_path)
# see if this is happening during GitHub actions setup
- if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check:
- assert config.get("formatter-cmds")[0] != "your-formatter $file", (
+ if config.get("formatter_cmds") and len(config.get("formatter_cmds")) > 0 and not override_formatter_check:
+ assert config.get("formatter_cmds")[0] != "your-formatter $file", (
"The formatter command is not set correctly in pyproject.toml. Please set the "
"formatter command in the 'formatter-cmds' key. More info - https://docs.codeflash.ai/configuration"
)
- for key in list(config.keys()):
- if "-" in key:
- config[key.replace("-", "_")] = config[key]
- del config[key]
return config, config_file_path
diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py
index 21afc3bd7..d19f4ea18 100644
--- a/codeflash/code_utils/git_utils.py
+++ b/codeflash/code_utils/git_utils.py
@@ -149,7 +149,7 @@ def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path:
def check_running_in_git_repo(module_root: str) -> bool:
try:
_ = git.Repo(module_root, search_parent_directories=True).git_dir
- except git.InvalidGitRepositoryError:
+ except (git.InvalidGitRepositoryError, git.NoSuchPathError):
return False
else:
return True
diff --git a/codeflash/main.py b/codeflash/main.py
index da0d83db6..1786968d1 100644
--- a/codeflash/main.py
+++ b/codeflash/main.py
@@ -6,6 +6,8 @@
from __future__ import annotations
+import copy
+import logging
import os
import sys
from pathlib import Path
@@ -17,16 +19,39 @@
warnings.filterwarnings("ignore")
-from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
+from codeflash.cli_cmds.cli import (
+ apply_language_config,
+ handle_optimize_all_arg_parsing,
+ parse_args,
+ process_pyproject_config,
+)
from codeflash.cli_cmds.console import paneled_text
from codeflash.code_utils import env_utils
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
-from codeflash.code_utils.config_parser import parse_config_file
+from codeflash.code_utils.config_parser import find_all_config_files, parse_config_file
from codeflash.code_utils.version_check import check_for_newer_minor_version
+from codeflash.languages.registry import UnsupportedLanguageError, get_language_support
if TYPE_CHECKING:
from argparse import Namespace
+ from codeflash.code_utils.config_parser import LanguageConfig
+
+
+def filter_configs_for_file(configs: list[LanguageConfig], file_path: str) -> list[LanguageConfig]:
+ resolved_file = Path(file_path).resolve()
+ matching = []
+ for config in configs:
+ config_root = config.config_path.resolve()
+ if config_root.is_file():
+ config_root = config_root.parent
+ try:
+ resolved_file.relative_to(config_root)
+ matching.append(config)
+ except ValueError:
+ continue
+ return matching if matching else configs
+
def main() -> None:
"""Entry point for the codeflash command-line interface."""
@@ -88,21 +113,92 @@ def main() -> None:
ask_run_end_to_end_test(args)
else:
- # Check for first-run experience (no config exists)
- loaded_args = _handle_config_loading(args)
- if loaded_args is None:
- sys.exit(0)
- args = loaded_args
+ language_configs = find_all_config_files()
- if not env_utils.check_formatter_installed(args.formatter_cmds):
- return
- args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
- init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
- posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
+ logger = logging.getLogger("codeflash")
+
+ if not language_configs:
+ # Fallback: no multi-config found, use existing single-config path
+ loaded_args = _handle_config_loading(args)
+ if loaded_args is None:
+ sys.exit(0)
+ args = loaded_args
+
+ if not env_utils.check_formatter_installed(args.formatter_cmds):
+ return
+ args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
+ init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
+ posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
- from codeflash.optimization import optimizer
+ from codeflash.optimization import optimizer
+
+ optimizer.run_with_args(args)
+ return
- optimizer.run_with_args(args)
+ # Filter to single language when --file is specified
+ if hasattr(args, "file") and args.file:
+ try:
+ file_lang_support = get_language_support(Path(args.file))
+ file_language = file_lang_support.language
+ matching_configs = [lc for lc in language_configs if lc.language == file_language]
+ if matching_configs:
+ language_configs = matching_configs
+ except UnsupportedLanguageError:
+ pass # Unknown extension, let all configs run
+
+ language_configs = filter_configs_for_file(language_configs, str(args.file))
+
+ # Save the raw --all value before handle_optimize_all_arg_parsing mutates it.
+ # In multi-language mode, module_root is None at this point so the resolution
+ # produces None for the default case; we re-resolve per language inside the loop.
+ original_all = getattr(args, "all", None) if hasattr(args, "all") else None
+ optimize_all_requested = hasattr(args, "all") and original_all is not None
+
+ # Multi-language path: run git/GitHub checks ONCE before the loop
+ args = handle_optimize_all_arg_parsing(args)
+
+ results: dict[str, str] = {}
+ for lang_config in language_configs:
+ lang_name = lang_config.language.value
+ try:
+ pass_args = copy.deepcopy(args)
+ pass_args = apply_language_config(pass_args, lang_config)
+
+ if optimize_all_requested:
+ if original_all == "":
+ # --all with no path: use this language's module_root
+ pass_args.all = pass_args.module_root
+ else:
+ # --all /specific/path: preserve the user's path
+ pass_args.all = Path(str(original_all)).resolve()
+
+ if not env_utils.check_formatter_installed(pass_args.formatter_cmds):
+ logger.info("Skipping %s: formatter not installed", lang_name)
+ results[lang_name] = "skipped"
+ continue
+
+ pass_args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(pass_args)
+ init_sentry(enabled=not pass_args.disable_telemetry, exclude_errors=True)
+ posthog_cf.initialize_posthog(enabled=not pass_args.disable_telemetry)
+
+ logger.info("Processing %s (config: %s)", lang_name, lang_config.config_path)
+
+ from codeflash.optimization import optimizer
+
+ optimizer.run_with_args(pass_args)
+ results[lang_name] = "success"
+ except Exception:
+ logger.exception("Error processing %s, continuing with remaining languages", lang_name)
+ results[lang_name] = "failed"
+
+ _log_orchestration_summary(logger, results)
+
+
+def _log_orchestration_summary(logger: logging.Logger, results: dict[str, str]) -> None:
+ if not results:
+ return
+ parts = [f"{lang}: {status}" for lang, status in results.items()]
+ logger.info("Multi-language orchestration complete: %s", ", ".join(parts))
def _handle_config_loading(args: Namespace) -> Namespace | None:
diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py
index f3f23c1d9..0666a6136 100644
--- a/tests/test_git_utils.py
+++ b/tests/test_git_utils.py
@@ -282,64 +282,145 @@ def helper():
"""
+UNSUPPORTED_LANG_DIFF = """\
+--- a/src/main.rs
++++ b/src/main.rs
+@@ -1,3 +1,4 @@
+ fn main() {
++ let x = 1;
+ println!("Hello");
+
+"""
+
+JS_TS_DIFF = """\
+--- a/src/app.js
++++ b/src/app.js
+@@ -1,3 +1,4 @@
+ function start() {
++ const x = 1;
+ return true;
+
+--- a/src/utils.ts
++++ b/src/utils.ts
+@@ -1,3 +1,4 @@
+ function helper() {
++ const y = 2;
+ return false;
+
+--- a/src/Component.jsx
++++ b/src/Component.jsx
+@@ -1,3 +1,4 @@
+ function Component() {
++ const a = null;
+ return null;
+
+--- a/src/Page.tsx
++++ b/src/Page.tsx
+@@ -1,3 +1,4 @@
+ function Page() {
++ const b = null;
+ return null;
+
+"""
+
+ALL_THREE_LANGS_DIFF = """\
+--- a/src/main.py
++++ b/src/main.py
+@@ -1,3 +1,4 @@
+ def main():
++ x = 1
+ return True
+
+--- a/src/Main.java
++++ b/src/Main.java
+@@ -1,3 +1,4 @@
+ public class Main {
++ int x = 1;
+ public static void main(String[] args) {}
+
+--- a/src/app.js
++++ b/src/app.js
+@@ -1,3 +1,4 @@
+ function app() {
++ const x = 1;
+ return true;
+
+--- a/src/utils.ts
++++ b/src/utils.ts
+@@ -1,3 +1,4 @@
+ function util() {
++ const y = 2;
+ return false;
+
+"""
+
+
class TestGetGitDiffMultiLanguage(unittest.TestCase):
@patch("codeflash.code_utils.git_utils.git.Repo")
- def test_java_diff_found_when_language_is_java(self, mock_repo_cls):
- from codeflash.languages.current import reset_current_language, set_current_language
-
+ def test_java_diff_found_without_singleton(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = JAVA_ADDITION_DIFF
- set_current_language("java")
- try:
- result = get_git_diff(repo_directory=None, uncommitted_changes=True)
- assert len(result) == 1
- key = list(result.keys())[0]
- assert str(key).endswith("Fibonacci.java")
- assert result[key] == [7, 8]
- finally:
- reset_current_language()
+ result = get_git_diff(repo_directory=None, uncommitted_changes=True)
+ assert len(result) == 1
+ key = list(result.keys())[0]
+ assert str(key).endswith("Fibonacci.java")
+ assert result[key] == [7, 8]
@patch("codeflash.code_utils.git_utils.git.Repo")
- def test_java_diff_found_regardless_of_current_language(self, mock_repo_cls):
- from codeflash.languages.current import reset_current_language, set_current_language
+ def test_unsupported_extension_still_filtered(self, mock_repo_cls):
+ repo = mock_repo_cls.return_value
+ repo.head.commit.hexsha = "abc123"
+ repo.working_dir = "/repo"
+ repo.git.diff.return_value = UNSUPPORTED_LANG_DIFF
+
+ result = get_git_diff(repo_directory=None, uncommitted_changes=True)
+ assert len(result) == 0
+ @patch("codeflash.code_utils.git_utils.git.Repo")
+ def test_mixed_lang_diff_returns_all_languages(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
- repo.git.diff.return_value = JAVA_ADDITION_DIFF
+ repo.git.diff.return_value = MIXED_LANG_DIFF
- # get_git_diff uses all registered extensions, not just the current language's
- set_current_language("python")
- try:
- result = get_git_diff(repo_directory=None, uncommitted_changes=True)
- assert len(result) == 1
- key = list(result.keys())[0]
- assert str(key).endswith("Fibonacci.java")
- finally:
- reset_current_language()
+ result = get_git_diff(repo_directory=None, uncommitted_changes=True)
+ assert len(result) == 2
+ keys = [str(k) for k in result.keys()]
+ assert any(k.endswith("utils.py") for k in keys)
+ assert any(k.endswith("App.java") for k in keys)
@patch("codeflash.code_utils.git_utils.git.Repo")
- def test_mixed_lang_diff_returns_all_supported_extensions(self, mock_repo_cls):
- from codeflash.languages.current import reset_current_language, set_current_language
+ def test_js_ts_extensions_found(self, mock_repo_cls):
+ repo = mock_repo_cls.return_value
+ repo.head.commit.hexsha = "abc123"
+ repo.working_dir = "/repo"
+ repo.git.diff.return_value = JS_TS_DIFF
+
+ result = get_git_diff(repo_directory=None, uncommitted_changes=True)
+ assert len(result) == 4
+ keys = [str(k) for k in result.keys()]
+ assert any(k.endswith("app.js") for k in keys)
+ assert any(k.endswith("utils.ts") for k in keys)
+ assert any(k.endswith("Component.jsx") for k in keys)
+ assert any(k.endswith("Page.tsx") for k in keys)
+ @patch("codeflash.code_utils.git_utils.git.Repo")
+ def test_mixed_all_three_languages(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
- repo.git.diff.return_value = MIXED_LANG_DIFF
+ repo.git.diff.return_value = ALL_THREE_LANGS_DIFF
- # All supported extensions are returned regardless of current language
- set_current_language("python")
- try:
- result = get_git_diff(repo_directory=None, uncommitted_changes=True)
- assert len(result) == 2
- paths = [str(k) for k in result.keys()]
- assert any(p.endswith("utils.py") for p in paths)
- assert any(p.endswith("App.java") for p in paths)
- finally:
- reset_current_language()
+ result = get_git_diff(repo_directory=None, uncommitted_changes=True)
+ assert len(result) == 4
+ keys = [str(k) for k in result.keys()]
+ assert any(k.endswith("main.py") for k in keys)
+ assert any(k.endswith("Main.java") for k in keys)
+ assert any(k.endswith("app.js") for k in keys)
+ assert any(k.endswith("utils.ts") for k in keys)
if __name__ == "__main__":
diff --git a/tests/test_languages/test_registry.py b/tests/test_languages/test_registry.py
index cdb44e1af..417a4a62e 100644
--- a/tests/test_languages/test_registry.py
+++ b/tests/test_languages/test_registry.py
@@ -272,6 +272,7 @@ def test_clear_registry_removes_everything(self):
assert not is_language_supported(Language.PYTHON)
# Re-register all languages by importing
+ from codeflash.languages.java.support import JavaSupport
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.python.support import PythonSupport
@@ -279,6 +280,7 @@ def test_clear_registry_removes_everything(self):
register_language(PythonSupport)
register_language(JavaScriptSupport)
register_language(TypeScriptSupport)
+ register_language(JavaSupport)
# Should be supported again
assert is_language_supported(Language.PYTHON)
diff --git a/tests/test_multi_config_discovery.py b/tests/test_multi_config_discovery.py
new file mode 100644
index 000000000..90cc7eca3
--- /dev/null
+++ b/tests/test_multi_config_discovery.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from unittest.mock import patch
+
+import tomlkit
+
+from codeflash.code_utils.config_parser import find_all_config_files
+from codeflash.languages.language_enum import Language
+
+
+def write_toml(path: Path, data: dict) -> None:
+ path.write_text(tomlkit.dumps(data), encoding="utf-8")
+
+
+class TestFindAllConfigFiles:
+ def test_finds_pyproject_toml_with_codeflash_section(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 1
+ assert result[0].language == Language.PYTHON
+ assert result[0].config_path == tmp_path / "pyproject.toml"
+
+ def test_finds_java_via_build_tool_detection(self, tmp_path: Path, monkeypatch) -> None:
+ java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")}
+ (tmp_path / "pom.xml").write_text("", encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 1
+ assert result[0].language == Language.JAVA
+ assert result[0].config_path == tmp_path
+
+ def test_finds_multiple_configs_python_and_java(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")}
+ (tmp_path / "pom.xml").write_text("", encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 2
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVA}
+
+ def test_skips_pyproject_without_codeflash_section(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"black": {"line-length": 120}}})
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 0
+
+ def test_finds_config_in_parent_directory(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ subdir = tmp_path / "subproject"
+ subdir.mkdir()
+ java_config = {"language": "java", "module_root": str(subdir / "src/main/java")}
+ (subdir / "pom.xml").write_text("", encoding="utf-8")
+ monkeypatch.chdir(subdir)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 2
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVA}
+
+ def test_closest_config_wins_per_language(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "."}}})
+ subdir = tmp_path / "sub"
+ subdir.mkdir()
+ write_toml(subdir / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ monkeypatch.chdir(subdir)
+ result = find_all_config_files()
+ assert len(result) == 1
+ assert result[0].language == Language.PYTHON
+ assert result[0].config_path == subdir / "pyproject.toml"
+
+ def test_finds_package_json_with_codeflash_section(self, tmp_path: Path, monkeypatch) -> None:
+ pkg = {"codeflash": {"moduleRoot": "src"}}
+ (tmp_path / "package.json").write_text(json.dumps(pkg), encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 1
+ assert result[0].language == Language.JAVASCRIPT
+ assert result[0].config_path == tmp_path / "package.json"
+
+ def test_finds_all_three_config_types(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ pkg = {"codeflash": {"moduleRoot": "src"}}
+ (tmp_path / "package.json").write_text(json.dumps(pkg), encoding="utf-8")
+ java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")}
+ (tmp_path / "pom.xml").write_text("", encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 3
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVA, Language.JAVASCRIPT}
+
+ def test_no_java_when_no_build_file_exists(self, tmp_path: Path, monkeypatch) -> None:
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 0
+
+ def test_missing_codeflash_section_skipped(self, tmp_path: Path, monkeypatch) -> None:
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"other": {"key": "value"}}})
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 0
+
+ def test_finds_java_in_subdirectory(self, tmp_path: Path, monkeypatch) -> None:
+ """Monorepo: Java project in a subdirectory is discovered from the repo root."""
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ java_dir = tmp_path / "java"
+ java_dir.mkdir()
+ (java_dir / "pom.xml").write_text("", encoding="utf-8")
+ java_config = {"language": "java", "module_root": str(java_dir / "src/main/java")}
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 2
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVA}
+ java_result = next(r for r in result if r.language == Language.JAVA)
+ assert java_result.config_path == java_dir
+
+ def test_finds_js_in_subdirectory(self, tmp_path: Path, monkeypatch) -> None:
+ """Monorepo: JS project in a subdirectory is discovered from the repo root."""
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ js_dir = tmp_path / "js"
+ js_dir.mkdir()
+ pkg = {"codeflash": {"moduleRoot": "src"}}
+ (js_dir / "package.json").write_text(json.dumps(pkg), encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 2
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVASCRIPT}
+
+ def test_finds_all_three_in_monorepo_subdirs(self, tmp_path: Path, monkeypatch) -> None:
+ """Monorepo: Python at root, Java and JS in subdirectories."""
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ java_dir = tmp_path / "java"
+ java_dir.mkdir()
+ (java_dir / "pom.xml").write_text("", encoding="utf-8")
+ java_config = {"language": "java", "module_root": str(java_dir / "src/main/java")}
+ js_dir = tmp_path / "js"
+ js_dir.mkdir()
+ pkg = {"codeflash": {"moduleRoot": "src"}}
+ (js_dir / "package.json").write_text(json.dumps(pkg), encoding="utf-8")
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ assert len(result) == 3
+ languages = {r.language for r in result}
+ assert languages == {Language.PYTHON, Language.JAVA, Language.JAVASCRIPT}
+
+ def test_skips_hidden_and_build_subdirs(self, tmp_path: Path, monkeypatch) -> None:
+ """Subdirectory scan skips .git, node_modules, target, etc."""
+ for name in [".git", "node_modules", "target", "build", "__pycache__"]:
+ d = tmp_path / name
+ d.mkdir()
+ write_toml(d / "pyproject.toml", {"tool": {"codeflash": {"module-root": "."}}})
+ monkeypatch.chdir(tmp_path)
+ result = find_all_config_files()
+ assert len(result) == 0
+
+ def test_root_config_wins_over_subdir(self, tmp_path: Path, monkeypatch) -> None:
+ """Config at CWD (found during upward walk) takes precedence over subdirectory."""
+ (tmp_path / "pom.xml").write_text("", encoding="utf-8")
+ java_dir = tmp_path / "java"
+ java_dir.mkdir()
+ (java_dir / "pom.xml").write_text("", encoding="utf-8")
+ java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")}
+ monkeypatch.chdir(tmp_path)
+ with patch(
+ "codeflash.code_utils.config_parser._parse_java_config_for_dir",
+ return_value=java_config,
+ ):
+ result = find_all_config_files()
+ java_results = [r for r in result if r.language == Language.JAVA]
+ assert len(java_results) == 1
+ assert java_results[0].config_path == tmp_path
+
+
+def test_find_all_functions_uses_registry_not_singleton() -> None:
+ """DISC-04: Verify find_all_functions_in_file uses per-file registry lookup."""
+ import inspect
+
+ from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
+
+ source = inspect.getsource(find_all_functions_in_file)
+ assert "get_language_support" in source
+ assert "current_language_support" not in source
diff --git a/tests/test_multi_language_orchestration.py b/tests/test_multi_language_orchestration.py
new file mode 100644
index 000000000..191d7a717
--- /dev/null
+++ b/tests/test_multi_language_orchestration.py
@@ -0,0 +1,819 @@
+from __future__ import annotations
+
+import logging
+from argparse import Namespace
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+import tomlkit
+
+from codeflash.code_utils.config_parser import LanguageConfig, normalize_toml_config
+from codeflash.languages.language_enum import Language
+
+
+def write_toml(path: Path, data: dict) -> None:
+ path.write_text(tomlkit.dumps(data), encoding="utf-8")
+
+
+def make_base_args(**overrides) -> Namespace:
+ defaults = {
+ "module_root": None,
+ "tests_root": None,
+ "benchmarks_root": None,
+ "ignore_paths": None,
+ "pytest_cmd": None,
+ "formatter_cmds": None,
+ "disable_telemetry": None,
+ "disable_imports_sorting": None,
+ "git_remote": None,
+ "override_fixtures": None,
+ "config_file": None,
+ "file": None,
+ "function": None,
+ "no_pr": False,
+ "verbose": False,
+ "command": None,
+ "verify_setup": False,
+ "version": False,
+ "show_config": False,
+ "reset_config": False,
+ "previous_checkpoint_functions": [],
+ }
+ defaults.update(overrides)
+ return Namespace(**defaults)
+
+
+class TestApplyLanguageConfig:
+ def test_sets_module_root(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ config = {"module_root": str(src)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.module_root == src.resolve()
+
+ def test_sets_tests_root(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ tests = tmp_path / "src" / "test" / "java"
+ tests.mkdir(parents=True)
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.tests_root == tests.resolve()
+
+ def test_resolves_paths_relative_to_config_parent(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ tests = tmp_path / "src" / "test" / "java"
+ tests.mkdir(parents=True)
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.module_root.is_absolute()
+ assert result.tests_root.is_absolute()
+
+ def test_sets_project_root(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ tests = tmp_path / "src" / "test" / "java"
+ tests.mkdir(parents=True)
+ (tmp_path / "pom.xml").touch()
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.project_root == tmp_path.resolve()
+
+ def test_preserves_cli_overrides(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ override_module = tmp_path / "custom"
+ override_module.mkdir()
+ tests = tmp_path / "src" / "test" / "java"
+ tests.mkdir(parents=True)
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args(module_root=str(override_module))
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.module_root == override_module.resolve()
+
+ def test_copies_formatter_cmds(self, tmp_path: Path) -> None:
+ src = tmp_path / "src"
+ src.mkdir()
+ tests = tmp_path / "tests"
+ tests.mkdir()
+ config = {"module_root": str(src), "tests_root": str(tests), "formatter_cmds": ["black $file"]}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path / "pyproject.toml", language=Language.PYTHON)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.formatter_cmds == ["black $file"]
+
+ def test_sets_language_singleton(self, tmp_path: Path) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ tests = tmp_path / "src" / "test" / "java"
+ tests.mkdir(parents=True)
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ with patch("codeflash.cli_cmds.cli.set_current_language") as mock_set:
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ apply_language_config(args, lang_config)
+ mock_set.assert_called_once_with(Language.JAVA)
+
+ def test_handles_python_config(self, tmp_path: Path) -> None:
+ src = tmp_path / "src"
+ src.mkdir()
+ tests = tmp_path / "tests"
+ tests.mkdir()
+ config = {"module_root": str(src), "tests_root": str(tests)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path / "pyproject.toml", language=Language.PYTHON)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.module_root == src.resolve()
+ assert result.tests_root == tests.resolve()
+
+ def test_java_default_tests_root(self, tmp_path: Path, monkeypatch) -> None:
+ src = tmp_path / "src" / "main" / "java"
+ src.mkdir(parents=True)
+ default_tests = tmp_path / "src" / "test" / "java"
+ default_tests.mkdir(parents=True)
+ monkeypatch.chdir(tmp_path)
+ config = {"module_root": str(src)}
+ lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA)
+ args = make_base_args()
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ result = apply_language_config(args, lang_config)
+ assert result.tests_root == default_tests.resolve()
+
+
+def make_lang_config(tmp_path: Path, language: Language, subdir: str = "") -> LanguageConfig:
+ if language == Language.PYTHON:
+ src = tmp_path / subdir / "src" if subdir else tmp_path / "src"
+ tests = tmp_path / subdir / "tests" if subdir else tmp_path / "tests"
+ src.mkdir(parents=True, exist_ok=True)
+ tests.mkdir(parents=True, exist_ok=True)
+ config_path = tmp_path / subdir / "pyproject.toml" if subdir else tmp_path / "pyproject.toml"
+ return LanguageConfig(
+ config={"module_root": str(src), "tests_root": str(tests)},
+ config_path=config_path,
+ language=Language.PYTHON,
+ )
+ if language == Language.JAVASCRIPT:
+ src = tmp_path / subdir / "src" if subdir else tmp_path / "src"
+ tests = tmp_path / subdir / "tests" if subdir else tmp_path / "tests"
+ src.mkdir(parents=True, exist_ok=True)
+ tests.mkdir(parents=True, exist_ok=True)
+ config_path = tmp_path / subdir / "package.json" if subdir else tmp_path / "package.json"
+ return LanguageConfig(
+ config={"module_root": str(src), "tests_root": str(tests)},
+ config_path=config_path,
+ language=Language.JAVASCRIPT,
+ )
+ src = tmp_path / subdir / "src" / "main" / "java" if subdir else tmp_path / "src" / "main" / "java"
+ tests = tmp_path / subdir / "src" / "test" / "java" if subdir else tmp_path / "src" / "test" / "java"
+ src.mkdir(parents=True, exist_ok=True)
+ tests.mkdir(parents=True, exist_ok=True)
+ config_path = tmp_path / subdir if subdir else tmp_path
+ return LanguageConfig(
+ config={"module_root": str(src), "tests_root": str(tests)},
+ config_path=config_path,
+ language=Language.JAVA,
+ )
+
+
+class TestMultiLanguageOrchestration:
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_sequential_passes_calls_optimizer_per_language(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ @patch("codeflash.cli_cmds.cli.set_current_language")
+ def test_singleton_set_per_pass(
+ self,
+ mock_set_lang,
+ _sentry,
+ _posthog,
+ _ver,
+ _banner,
+ mock_parse_args,
+ mock_find_configs,
+ mock_run,
+ _handle_all,
+ _fmt,
+ _ckpt,
+ tmp_path: Path,
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ # set_current_language is called once per language pass via apply_language_config
+ lang_calls = [c for c in mock_set_lang.call_args_list if c[0][0] in (Language.PYTHON, Language.JAVA)]
+ assert len(lang_calls) >= 2
+ called_langs = {c[0][0] for c in lang_calls}
+ assert Language.PYTHON in called_langs
+ assert Language.JAVA in called_langs
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files", return_value=[])
+ @patch("codeflash.main._handle_config_loading")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_fallback_to_single_config_when_no_multi_configs(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_handle_config, mock_run, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ base = make_base_args(
+ disable_telemetry=False, formatter_cmds=[], module_root=str(tmp_path), tests_root=str(tmp_path)
+ )
+ mock_parse_args.return_value = base
+ mock_handle_config.return_value = base
+
+ from codeflash.main import main
+
+ main()
+
+ mock_handle_config.assert_called_once()
+ mock_run.assert_called_once()
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_args_deep_copied_between_passes(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+ call1_args = mock_run.call_args_list[0][0][0]
+ call2_args = mock_run.call_args_list[1][0][0]
+ # Args should be different objects (deep copied)
+ assert call1_args is not call2_args
+ # Module roots should differ between Python and Java configs
+ assert call1_args.module_root != call2_args.module_root
+
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_error_in_one_language_does_not_block_others(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+ # First call (Python) raises, second call (Java) succeeds
+ mock_run.side_effect = [RuntimeError("Python optimizer crashed"), None]
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_orchestration_summary_logged(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ with patch("codeflash.main._log_orchestration_summary") as mock_summary:
+ from codeflash.main import main
+
+ main()
+
+ mock_summary.assert_called_once()
+ results = mock_summary.call_args[0][1]
+ assert results["python"] == "success"
+ assert results["java"] == "success"
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_summary_reports_failure_status(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+ mock_run.side_effect = [RuntimeError("boom"), None]
+
+ with patch("codeflash.main._log_orchestration_summary") as mock_summary:
+ from codeflash.main import main
+
+ main()
+
+ results = mock_summary.call_args[0][1]
+ assert results["python"] == "failed"
+ assert results["java"] == "success"
+
+
+class TestOrchestrationSummaryLogging:
+ def test_summary_format_all_success(self) -> None:
+ import logging
+
+ from codeflash.main import _log_orchestration_summary
+
+ with patch.object(logging.Logger, "info") as mock_info:
+ logger = logging.getLogger("codeflash.test")
+ _log_orchestration_summary(logger, {"python": "success", "java": "success"})
+ mock_info.assert_called_once()
+ msg = mock_info.call_args[0][0] % mock_info.call_args[0][1:]
+ assert "python: success" in msg
+ assert "java: success" in msg
+
+ def test_summary_format_mixed_statuses(self) -> None:
+ import logging
+
+ from codeflash.main import _log_orchestration_summary
+
+ with patch.object(logging.Logger, "info") as mock_info:
+ logger = logging.getLogger("codeflash.test")
+ _log_orchestration_summary(logger, {"python": "failed", "java": "success", "javascript": "skipped"})
+ mock_info.assert_called_once()
+ msg = mock_info.call_args[0][0] % mock_info.call_args[0][1:]
+ assert "python: failed" in msg
+ assert "java: success" in msg
+ assert "javascript: skipped" in msg
+
+ def test_summary_no_results_no_log(self) -> None:
+ import logging
+
+ from codeflash.main import _log_orchestration_summary
+
+ with patch.object(logging.Logger, "info") as mock_info:
+ logger = logging.getLogger("codeflash.test")
+ _log_orchestration_summary(logger, {})
+ mock_info.assert_not_called()
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed")
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_summary_reports_skipped_status(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, mock_fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+ # Python formatter check fails (skipped), Java succeeds
+ mock_fmt.side_effect = [False, True]
+
+ with patch("codeflash.main._log_orchestration_summary") as mock_summary:
+ from codeflash.main import main
+
+ main()
+
+ results = mock_summary.call_args[0][1]
+ assert results["python"] == "skipped"
+ assert results["java"] == "success"
+ assert mock_run.call_count == 1
+
+
+class TestCLIPathRouting:
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_filters_to_matching_language(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(file="path/to/Foo.java", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 1
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_python_file_filters_to_python(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(file="module.py", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 1
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_unknown_extension_runs_all(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(file="Foo.rs", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_no_matching_config_runs_all(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ mock_find_configs.return_value = [py_config]
+ mock_parse_args.return_value = make_base_args(file="Foo.java", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 1
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_all_flag_sets_module_root_per_language(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(all="", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+ for call in mock_run.call_args_list:
+ passed_args = call[0][0]
+ assert passed_args.all == passed_args.module_root
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_no_flags_runs_all_language_passes(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ java_config = make_lang_config(tmp_path, Language.JAVA)
+ mock_find_configs.return_value = [py_config, java_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 2
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_typescript_extension(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ # .tsx maps to Language.TYPESCRIPT, which is distinct from Language.JAVASCRIPT.
+ # When no TYPESCRIPT config exists, all configs run (fallback behavior).
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ js_config = make_lang_config(tmp_path, Language.JAVASCRIPT, subdir="js-proj")
+ mock_find_configs.return_value = [py_config, js_config]
+ mock_parse_args.return_value = make_base_args(file="path/to/Component.tsx", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ # No TYPESCRIPT config exists, so all configs run (same as unknown extension)
+ assert mock_run.call_count == 2
+
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_file_flag_jsx_extension(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ # .jsx maps to Language.JAVASCRIPT, so it correctly filters to the JS config.
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ js_config = make_lang_config(tmp_path, Language.JAVASCRIPT, subdir="js-proj")
+ mock_find_configs.return_value = [py_config, js_config]
+ mock_parse_args.return_value = make_base_args(file="path/to/Widget.jsx", disable_telemetry=False)
+
+ from codeflash.main import main
+
+ main()
+
+ assert mock_run.call_count == 1
+
+
+class TestDirectFunctionCoverage:
+ def test_empty_config_no_module_root(self, tmp_path: Path) -> None:
+ config: dict = {}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["formatter_cmds"] == []
+ assert result["disable_telemetry"] is False
+ assert "module_root" not in result
+
+
+class TestNormalizeTomlConfig:
+ def test_converts_hyphenated_keys_to_underscored(self, tmp_path: Path) -> None:
+ config = {"module-root": "src", "tests-root": "tests"}
+ (tmp_path / "src").mkdir()
+ (tmp_path / "tests").mkdir()
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert "module_root" in result
+ assert "tests_root" in result
+ assert "module-root" not in result
+ assert "tests-root" not in result
+
+ def test_resolves_paths_relative_to_config_parent(self, tmp_path: Path) -> None:
+ src = tmp_path / "src"
+ src.mkdir()
+ config = {"module-root": "src"}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["module_root"] == str(src.resolve())
+
+ def test_applies_default_values(self, tmp_path: Path) -> None:
+ config: dict = {}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["formatter_cmds"] == []
+ assert result["disable_telemetry"] is False
+ assert result["override_fixtures"] is False
+ assert result["git_remote"] == "origin"
+ assert result["pytest_cmd"] == "pytest"
+
+ def test_preserves_explicit_values(self, tmp_path: Path) -> None:
+ config = {"disable-telemetry": True, "formatter-cmds": ["prettier $file"]}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["disable_telemetry"] is True
+ assert result["formatter_cmds"] == ["prettier $file"]
+
+ def test_resolves_ignore_paths(self, tmp_path: Path) -> None:
+ config = {"ignore-paths": ["build", "dist"]}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["ignore_paths"] == [
+ str((tmp_path / "build").resolve()),
+ str((tmp_path / "dist").resolve()),
+ ]
+
+ def test_empty_ignore_paths_default(self, tmp_path: Path) -> None:
+ config: dict = {}
+ result = normalize_toml_config(config, tmp_path / "pyproject.toml")
+ assert result["ignore_paths"] == []
+
+
+class TestPerLanguageLogging:
+ @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[])
+ @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True)
+ @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args)
+ @patch("codeflash.optimization.optimizer.run_with_args")
+ @patch("codeflash.main.find_all_config_files")
+ @patch("codeflash.main.parse_args")
+ @patch("codeflash.main.print_codeflash_banner")
+ @patch("codeflash.main.check_for_newer_minor_version")
+ @patch("codeflash.telemetry.posthog_cf.initialize_posthog")
+ @patch("codeflash.telemetry.sentry.init_sentry")
+ def test_per_language_logging_shows_config_path(
+ self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path
+ ) -> None:
+ py_config = make_lang_config(tmp_path, Language.PYTHON)
+ mock_find_configs.return_value = [py_config]
+ mock_parse_args.return_value = make_base_args(disable_telemetry=False)
+
+ with patch("codeflash.main._log_orchestration_summary"):
+ from codeflash.main import main
+
+ with patch("logging.Logger.info") as mock_log_info:
+ main()
+ logged_messages = [str(call) for call in mock_log_info.call_args_list]
+ processing_logs = [m for m in logged_messages if "Processing" in m and "config:" in m]
+ assert len(processing_logs) >= 1
+
+
+class TestGitRepoDetectionEdgeCases:
+ def test_check_running_in_git_repo_nonexistent_path(self) -> None:
+ from codeflash.code_utils.git_utils import check_running_in_git_repo
+
+ assert check_running_in_git_repo("/nonexistent/path/that/does/not/exist") is False
+
+ def test_check_running_in_git_repo_none_uses_cwd(self) -> None:
+ from codeflash.code_utils.git_utils import check_running_in_git_repo
+
+ # None defaults to CWD, which is inside the codeflash git repo
+ assert check_running_in_git_repo(None) is True
+
+ def test_handle_optimize_all_git_remote_defaults_to_origin(self) -> None:
+ import git as git_module
+
+ from codeflash.cli_cmds.cli import handle_optimize_all_arg_parsing
+
+ args = make_base_args(file="/some/file.java", no_pr=False)
+ # Remove git_remote to simulate multi-config path where config hasn't loaded
+ if hasattr(args, "git_remote"):
+ delattr(args, "git_remote")
+
+ mock_repo = MagicMock(spec=git_module.Repo)
+ with (
+ patch("git.Repo", return_value=mock_repo),
+ patch(
+ "codeflash.code_utils.git_utils.check_and_push_branch", return_value=True
+ ) as mock_push,
+ patch("codeflash.code_utils.git_utils.get_repo_owner_and_name", return_value=("owner", "repo")),
+ patch("codeflash.code_utils.github_utils.require_github_app_or_exit"),
+ ):
+ handle_optimize_all_arg_parsing(args)
+ mock_push.assert_called_once()
+ assert mock_push.call_args[1].get("git_remote") == "origin"
+
+ def test_handle_optimize_all_no_such_path_error(self) -> None:
+ import git as git_module
+
+ from codeflash.cli_cmds.cli import handle_optimize_all_arg_parsing
+
+ args = make_base_args(file="/some/file.java", no_pr=False)
+
+ with patch("git.Repo", side_effect=git_module.exc.NoSuchPathError("/bad/path")):
+ with pytest.raises(SystemExit):
+ handle_optimize_all_arg_parsing(args)
diff --git a/tests/test_spurious_java_config.py b/tests/test_spurious_java_config.py
new file mode 100644
index 000000000..682e8a763
--- /dev/null
+++ b/tests/test_spurious_java_config.py
@@ -0,0 +1,142 @@
+"""Test that spurious Java configs (like codeflash-java-runtime/) don't crash --file optimizations.
+
+Reproduces the bug where running:
+ codeflash --file tests/.../Calculator.java --module-root tests/.../src/main/java --tests-root tests/.../src/test/java
+
+from a repo that contains codeflash-java-runtime/ (which has a pom.xml) crashes with:
+ ValueError: File .../Calculator.java is not within the project root .../codeflash-java-runtime
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import patch
+
+import tomlkit
+
+from codeflash.code_utils.config_parser import LanguageConfig, find_all_config_files
+from codeflash.languages.language_enum import Language
+
+
+def write_toml(path: Path, data: dict) -> None:
+ path.write_text(tomlkit.dumps(data), encoding="utf-8")
+
+
+class TestSpuriousJavaConfigDiscovery:
+ def test_subdirectory_with_pom_picked_up_as_java_config(self, tmp_path: Path, monkeypatch) -> None:
+ """Verify the bug scenario: a subdir with pom.xml gets picked up as Java config."""
+ # Root has a pyproject.toml (Python project, like codeflash itself)
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}})
+ (tmp_path / "src").mkdir()
+
+ # Subdirectory mimicking codeflash-java-runtime (has pom.xml)
+ runtime_dir = tmp_path / "codeflash-java-runtime"
+ runtime_dir.mkdir()
+ (runtime_dir / "pom.xml").write_text("", encoding="utf-8")
+
+ java_config = {"language": "java", "module_root": str(runtime_dir / "src/main/java")}
+ monkeypatch.chdir(tmp_path)
+
+ with patch("codeflash.code_utils.config_parser._parse_java_config_for_dir", return_value=java_config):
+ result = find_all_config_files()
+
+ # This demonstrates the bug: codeflash-java-runtime gets picked up
+ java_configs = [r for r in result if r.language == Language.JAVA]
+ assert len(java_configs) == 1
+ assert java_configs[0].config_path == runtime_dir
+
+ def test_file_flag_with_spurious_java_config_should_not_crash(self, tmp_path: Path, monkeypatch) -> None:
+ """The actual bug: --file Calculator.java crashes because project_root points to codeflash-java-runtime."""
+ # Setup: Python project at root with codeflash-java-runtime subdir
+ write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "codeflash"}}})
+ (tmp_path / "codeflash").mkdir()
+
+ # Spurious Java subdir (like codeflash-java-runtime)
+ runtime_dir = tmp_path / "codeflash-java-runtime"
+ runtime_dir.mkdir()
+ (runtime_dir / "pom.xml").write_text("", encoding="utf-8")
+ (runtime_dir / "src" / "main" / "java").mkdir(parents=True)
+
+ # Actual target: Java fixture in a completely different location
+ fixture_dir = tmp_path / "tests" / "fixtures" / "java_maven"
+ (fixture_dir / "src" / "main" / "java" / "com" / "example").mkdir(parents=True)
+ (fixture_dir / "src" / "test" / "java").mkdir(parents=True)
+ target_file = fixture_dir / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
+ target_file.write_text(
+ "public class Calculator { public int add(int a, int b) { return a + b; } }", encoding="utf-8"
+ )
+
+ monkeypatch.chdir(tmp_path)
+
+ runtime_java_config = {"language": "java", "module_root": str(runtime_dir / "src" / "main" / "java")}
+
+ from codeflash.cli_cmds.cli import apply_language_config
+
+ # Simulate what main() does: discover configs, filter by language, apply config
+ with patch("codeflash.code_utils.config_parser._parse_java_config_for_dir", return_value=runtime_java_config):
+ configs = find_all_config_files()
+
+ java_configs = [c for c in configs if c.language == Language.JAVA]
+ assert len(java_configs) == 1
+
+ # Now simulate what happens: user provided --file and --module-root explicitly
+ from tests.test_multi_language_orchestration import make_base_args
+
+ args = make_base_args(
+ file=str(target_file),
+ module_root=str(fixture_dir / "src" / "main" / "java"),
+ tests_root=str(fixture_dir / "src" / "test" / "java"),
+ )
+
+ # This is where it crashes: apply_language_config sets project_root to
+ # codeflash-java-runtime/ (config_path), then later module_name_from_file_path
+ # fails because Calculator.java is not within codeflash-java-runtime/
+ result = apply_language_config(args, java_configs[0])
+
+ # The bug: project_root is set to the spurious config path, not the user's target
+ # After the fix, the file should be within project_root
+ resolved_file = Path(args.file).resolve()
+ assert resolved_file.is_relative_to(result.project_root), (
+ f"File {resolved_file} is not within project_root {result.project_root}"
+ )
+
+
+class TestFileNotWithinDiscoveredProjectRoot:
+ def test_orchestrator_skips_config_when_file_outside_project_root(self, tmp_path: Path, monkeypatch) -> None:
+ """When --file points to a file outside a discovered config's project root, skip that config."""
+ # Two Java configs: one correct, one spurious
+ correct_dir = tmp_path / "my-java-project"
+ (correct_dir / "src" / "main" / "java").mkdir(parents=True)
+ (correct_dir / "src" / "test" / "java").mkdir(parents=True)
+ target_file = correct_dir / "src" / "main" / "java" / "Foo.java"
+ target_file.write_text("public class Foo {}", encoding="utf-8")
+
+ spurious_dir = tmp_path / "runtime-lib"
+ (spurious_dir / "src" / "main" / "java").mkdir(parents=True)
+ (spurious_dir / "src" / "test" / "java").mkdir(parents=True)
+
+ correct_config = LanguageConfig(
+ config={
+ "module_root": str(correct_dir / "src/main/java"),
+ "tests_root": str(correct_dir / "src/test/java"),
+ },
+ config_path=correct_dir,
+ language=Language.JAVA,
+ )
+ spurious_config = LanguageConfig(
+ config={
+ "module_root": str(spurious_dir / "src/main/java"),
+ "tests_root": str(spurious_dir / "src/test/java"),
+ },
+ config_path=spurious_dir,
+ language=Language.JAVA,
+ )
+
+ monkeypatch.chdir(tmp_path)
+
+ from codeflash.main import filter_configs_for_file
+
+ # After the fix, this function should exist and filter out spurious configs
+ filtered = filter_configs_for_file([spurious_config, correct_config], str(target_file))
+ assert len(filtered) == 1
+ assert filtered[0].config_path == correct_dir