Skip to content
102 changes: 63 additions & 39 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import logging
import os
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
from argparse import SUPPRESS, ArgumentParser
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.console import apologize_and_exit, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.languages import set_current_language
from codeflash.languages.language_enum import Language
from codeflash.languages.test_framework import set_current_test_framework
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version

if TYPE_CHECKING:
from argparse import Namespace

from codeflash.code_utils.config_parser import LanguageConfig


def parse_args() -> Namespace:
parser = _build_parser()
Expand Down Expand Up @@ -89,6 +99,23 @@ def process_pyproject_config(args: Namespace) -> Namespace:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)

language = None
lang_str = pyproject_config.get("language")
if lang_str:
language = Language(lang_str)

args = resolve_config_onto_args(args, pyproject_config, pyproject_file_path, language)

if is_LSP_enabled():
args.all = None
return args
return handle_optimize_all_arg_parsing(args)


def resolve_config_onto_args(
args: Namespace, config: dict[str, Any], config_path: Path, language: Language | None
) -> Namespace:
supported_keys = [
"module_root",
"tests_root",
Expand All @@ -102,26 +129,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"override_fixtures",
]
for key in supported_keys:
if key in pyproject_config and (
(hasattr(args, key.replace("-", "_")) and getattr(args, key.replace("-", "_")) is None)
or not hasattr(args, key.replace("-", "_"))
attr_key = key.replace("-", "_")
if key in config and (
(hasattr(args, attr_key) and getattr(args, attr_key) is None) or not hasattr(args, attr_key)
):
setattr(args, key.replace("-", "_"), pyproject_config[key])
setattr(args, attr_key, config[key])

assert args.module_root is not None, "--module-root must be specified"
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"

# For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
# Default to module_root if not specified
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
is_java_project = pyproject_config.get("language") == "java"
is_js_ts = language in (Language.JAVASCRIPT, Language.TYPESCRIPT)
is_java = language == Language.JAVA

# Set the test framework singleton for JS/TS projects
if is_js_ts_project and pyproject_config.get("test_framework"):
set_current_test_framework(pyproject_config["test_framework"])
if language is not None:
set_current_language(language)

if is_js_ts and config.get("test_framework"):
set_current_test_framework(config["test_framework"])

if args.tests_root is None:
if is_java_project:
# Try standard Maven/Gradle test directories
if is_java:
for test_dir in ["src/test/java", "test", "tests"]:
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
if not test_path.is_absolute():
Expand All @@ -131,29 +158,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
break
if args.tests_root is None:
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
elif is_js_ts_project:
# Try common JS test directories at project root first
elif is_js_ts:
for test_dir in ["test", "tests", "__tests__"]:
if Path(test_dir).is_dir():
args.tests_root = test_dir
break
# If not found at project root, try inside module_root (e.g., src/test, src/__tests__)
if args.tests_root is None and args.module_root:
module_root_path = Path(args.module_root)
for test_dir in ["test", "tests", "__tests__"]:
test_path = module_root_path / test_dir
if test_path.is_dir():
args.tests_root = str(test_path)
break
# Final fallback: default to module_root
# Note: This may cause issues if tests are colocated with source files
# In such cases, the user should explicitly configure testsRoot in package.json
if args.tests_root is None:
args.tests_root = args.module_root
else:
raise AssertionError("--tests-root must be specified")

assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
if args.benchmark:

if getattr(args, "benchmark", False):
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
assert Path(args.benchmarks_root).is_dir(), (
f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
Expand All @@ -179,29 +203,25 @@ def process_pyproject_config(args: Namespace) -> Namespace:

require_github_app_or_exit(owner, repo_name)

# Project root path is one level above the specified directory, because that's where the module can be imported from
args.module_root = Path(args.module_root).resolve()
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
# Normalize ignore paths, supporting both literal paths and glob patterns
# Use module_root as base path for resolving relative paths and patterns
args.ignore_paths = normalize_ignore_paths(args.ignore_paths, base_path=args.module_root)
# If module-root is "." then all imports are relatives to it.
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
args.project_root = project_root_from_module_root(Path(args.module_root), pyproject_file_path)
args.project_root = project_root_from_module_root(Path(args.module_root), config_path)
args.tests_root = Path(args.tests_root).resolve()
if args.benchmarks_root:
args.benchmarks_root = Path(args.benchmarks_root).resolve()
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
args.test_project_root = project_root_from_module_root(args.tests_root, config_path)

if is_java_project and pyproject_file_path.is_dir():
# For Java projects, pyproject_file_path IS the project root directory (not a file).
# Override project_root which may have resolved to a sub-module.
args.project_root = pyproject_file_path.resolve()
args.test_project_root = pyproject_file_path.resolve()
if is_LSP_enabled():
args.all = None
return args
return handle_optimize_all_arg_parsing(args)
if is_java and config_path.is_dir():
resolved_config = config_path.resolve()
try:
args.module_root.relative_to(resolved_config)
args.project_root = resolved_config
args.test_project_root = resolved_config
except ValueError:
pass

return args


def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
Expand All @@ -221,6 +241,10 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
return module_root.parent.resolve()


def apply_language_config(args: Namespace, lang_config: LanguageConfig) -> Namespace:
return resolve_config_onto_args(args, lang_config.config, lang_config.config_path, lang_config.language)


def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
no_pr = getattr(args, "no_pr", False)
Expand All @@ -234,14 +258,14 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
mode = "--all" if hasattr(args, "all") else "--file"
logger.exception(
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
git_remote = getattr(args, "git_remote", None)
git_remote = getattr(args, "git_remote", None) or "origin"
if not check_and_push_branch(git_repo, git_remote=git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
Expand Down
Loading
Loading