Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions codeflash/languages/javascript/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path,
# Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields
class_info = self._find_class_definition(source, class_name, analyzer, function.function_name)
if class_info:
class_jsdoc, class_indent, constructor_code, fields_code = class_info
class_jsdoc, class_indent, constructor_code, fields_code, is_exported = class_info
# Build the class body with fields, constructor, target method, and same-class helpers
class_body_parts = []
if fields_code:
Expand All @@ -364,12 +364,14 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path,
class_body = "\n".join(class_body_parts)

# Wrap the method in a class definition with context
# Include 'export' keyword if the class is exported
export_keyword = "export " if is_exported else ""
if class_jsdoc:
target_code = f"{class_jsdoc}\n{class_indent}{export_keyword}class {class_name} {{\n{class_body}{class_indent}}}\n"
else:
target_code = (
f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
f"{class_indent}{export_keyword}class {class_name} {{\n{class_body}{class_indent}}}\n"
)
else:
target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
else:
# Fallback: wrap with no indentation, including same-class helpers
helper_code = "\n".join(h[1] for h in same_class_helpers)
Expand Down Expand Up @@ -432,7 +434,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path,

def _find_class_definition(
self, source: str, class_name: str, analyzer: TreeSitterAnalyzer, target_method_name: str | None = None
) -> tuple[str, str, str, str] | None:
) -> tuple[str, str, str, str, bool] | None:
"""Find a class definition and extract its JSDoc, indentation, constructor, and fields.

Args:
Expand All @@ -442,7 +444,7 @@ def _find_class_definition(
target_method_name: Name of the target method (to exclude from extracted context).

Returns:
Tuple of (jsdoc_comment, indentation, constructor_code, fields_code) or None if not found.
Tuple of (jsdoc_comment, indentation, constructor_code, fields_code, is_exported) or None if not found.
Constructor and fields are included to provide context for method optimization.

"""
Expand All @@ -467,6 +469,11 @@ def find_class_node(node):
if not class_node:
return None

# Check if the class is exported by examining its parent node
is_exported = False
if class_node.parent and class_node.parent.type == "export_statement":
is_exported = True

# Get indentation from the class line
lines = source.splitlines(keepends=True)
class_line_idx = class_node.start_point[0]
Expand Down Expand Up @@ -495,7 +502,7 @@ def find_class_node(node):
body_node, source_bytes, lines, target_method_name
)

return (jsdoc, indentation, constructor_code, fields_code)
return (jsdoc, indentation, constructor_code, fields_code, is_exported)

def _extract_class_context(
self, body_node: Any, source_bytes: bytes, lines: list[str], target_method_name: str | None
Expand Down Expand Up @@ -2164,8 +2171,9 @@ def find_test_root(self, project_root: Path) -> Path | None:
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for importing a JavaScript source file from tests.

For JavaScript, this returns a relative path from the tests directory to the source file
(e.g., '../fibonacci' for source at /project/fibonacci.js and tests at /project/tests/).
For JavaScript/TypeScript, this returns a relative path from the tests directory to
the source file. For ESM projects or TypeScript, the path includes a .js extension
(TypeScript convention). For CommonJS, no extension is added.

Args:
source_file: Path to the source file.
Expand All @@ -2179,13 +2187,15 @@ def get_module_path(self, source_file: Path, project_root: Path, tests_root: Pat
import os

from codeflash.cli_cmds.console import logger
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system

if tests_root is None:
tests_root = self.find_test_root(project_root) or project_root

try:
# Resolve both paths to absolute to ensure consistent relative path calculation
source_file_abs = source_file.resolve().with_suffix("")
# Note: Don't remove extension yet - we'll decide based on module system
source_file_abs = source_file.resolve()
tests_root_abs = tests_root.resolve()

# Find the project root using language support
Expand All @@ -2205,16 +2215,40 @@ def get_module_path(self, source_file: Path, project_root: Path, tests_root: Pat
if not tests_root_abs.exists():
tests_root_abs = project_root_from_lang

# Detect module system to determine if we need to add .js extension
module_system = detect_module_system(project_root, source_file)

# Remove source file extension first
source_without_ext = source_file_abs.with_suffix("")

# Use os.path.relpath to compute relative path from tests_root to source file
rel_path = os.path.relpath(str(source_file_abs), str(tests_root_abs))
logger.debug(
f"!lsp|Module path: source={source_file_abs}, tests_root={tests_root_abs}, rel_path={rel_path}"
)
rel_path = os.path.relpath(str(source_without_ext), str(tests_root_abs))

# For ESM, add .js extension (TypeScript convention)
# TypeScript requires imports to reference the OUTPUT file extension (.js),
# even when the source file is .ts. This is required for Node.js ESM resolution.
if module_system == ModuleSystem.ES_MODULE:
rel_path = rel_path + ".js"
logger.debug(
f"!lsp|Module path (ESM): source={source_file_abs}, tests_root={tests_root_abs}, "
f"rel_path={rel_path} (added .js for ESM)"
)
else:
logger.debug(
f"!lsp|Module path (CommonJS): source={source_file_abs}, tests_root={tests_root_abs}, "
f"rel_path={rel_path}"
)

return rel_path
except ValueError:
# Fallback if paths are on different drives (Windows)
rel_path = source_file.relative_to(project_root)
return "../" + rel_path.with_suffix("").as_posix()
# For fallback, also check module system
module_system = detect_module_system(project_root, source_file)
path_without_ext = "../" + rel_path.with_suffix("").as_posix()
if module_system == ModuleSystem.ES_MODULE:
return path_without_ext + ".js"
return path_without_ext

def verify_requirements(self, project_root: Path, test_framework: str = "jest") -> tuple[bool, list[str]]:
"""Verify that all JavaScript requirements are met.
Expand Down
166 changes: 166 additions & 0 deletions tests/test_export_keyword_preservation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Test that export keywords are preserved during code extraction."""

import pytest
from pathlib import Path
from codeflash.languages.javascript.support import JavaScriptSupport
from codeflash.models.function_types import FunctionToOptimize


class TestExportKeywordPreservation:
"""Test export keyword is preserved when extracting code context."""

def test_export_class_includes_export_keyword(self, tmp_path: Path):
"""Test that exported classes include the export keyword in extracted code."""
# Arrange: Create a test file with an exported class
test_file = tmp_path / "test.ts"
test_code = """export class WsContextCreator {
public getMetadata(instance: any, methodName: string): any {
return { test: true };
}
}"""
test_file.write_text(test_code)

# Create a FunctionToOptimize for the class method
from codeflash.models.function_types import FunctionParent
function = FunctionToOptimize(
function_name="getMetadata",
file_path=test_file,
parents=[FunctionParent(type="ClassDef", name="WsContextCreator")],
starting_line=2, # The method starts at line 2
ending_line=4, # The method ends at line 4
starting_col=2,
ending_col=3,
is_async=False,
is_method=True,
language="typescript",
doc_start_line=None,
)

# Act: Extract code context
support = JavaScriptSupport()
context = support.extract_code_context(function, tmp_path, tmp_path)

# Assert: The extracted code should include the export keyword
assert "export" in context.target_code, (
f"Export keyword missing from extracted code. "
f"Expected code to start with 'export class', but got:\n{context.target_code}"
)
assert "export class WsContextCreator" in context.target_code, (
f"Expected 'export class WsContextCreator' in extracted code, but got:\n{context.target_code}"
)

def test_export_function_includes_export_keyword(self, tmp_path: Path):
"""Test that exported functions include the export keyword in extracted code."""
# Arrange: Create a test file with an exported function
test_file = tmp_path / "test.ts"
test_code = """export function helperFunction(a: number, b: number): number {
return a + b;
}"""
test_file.write_text(test_code)

# Create a FunctionToOptimize for the function
function = FunctionToOptimize(
function_name="helperFunction",
file_path=test_file,
parents=[],
starting_line=1,
ending_line=3,
starting_col=0,
ending_col=1,
is_async=False,
is_method=False,
language="typescript",
doc_start_line=None,
)

# Act: Extract code context
support = JavaScriptSupport()
context = support.extract_code_context(function, tmp_path, tmp_path)

# Assert: The extracted code should include the export keyword
assert "export" in context.target_code, (
f"Export keyword missing from extracted code for function. "
f"Expected code to start with 'export function', but got:\n{context.target_code}"
)
assert "export function helperFunction" in context.target_code, (
f"Expected 'export function helperFunction' in extracted code, but got:\n{context.target_code}"
)

def test_export_const_arrow_function_includes_export(self, tmp_path: Path):
"""Test that exported const arrow functions include the export keyword."""
# Arrange: Create a test file with an exported const arrow function
test_file = tmp_path / "test.ts"
test_code = """export const multiply = (a: number, b: number): number => {
return a * b;
};"""
test_file.write_text(test_code)

# Create a FunctionToOptimize for the arrow function
function = FunctionToOptimize(
function_name="multiply",
file_path=test_file,
parents=[],
starting_line=1,
ending_line=3,
starting_col=0,
ending_col=2,
is_async=False,
is_method=False,
language="typescript",
doc_start_line=None,
)

# Act: Extract code context
support = JavaScriptSupport()
context = support.extract_code_context(function, tmp_path, tmp_path)

# Assert: The extracted code should include the export keyword
assert "export" in context.target_code, (
f"Export keyword missing from exported const arrow function. "
f"Expected code to start with 'export const', but got:\n{context.target_code}"
)
assert "export const multiply" in context.target_code, (
f"Expected 'export const multiply' in extracted code, but got:\n{context.target_code}"
)

def test_non_exported_class_unchanged(self, tmp_path: Path):
"""Test that non-exported classes work correctly (baseline test)."""
# Arrange: Create a test file with a NON-exported class
test_file = tmp_path / "test.ts"
test_code = """class InternalHelper {
process(): void {
console.log('internal');
}
}"""
test_file.write_text(test_code)

# Create a FunctionToOptimize for the method
from codeflash.models.function_types import FunctionParent
function = FunctionToOptimize(
function_name="process",
file_path=test_file,
parents=[FunctionParent(type="ClassDef", name="InternalHelper")],
starting_line=2,
ending_line=4,
starting_col=2,
ending_col=3,
is_async=False,
is_method=True,
language="typescript",
doc_start_line=None,
)

# Act: Extract code context
support = JavaScriptSupport()
context = support.extract_code_context(function, tmp_path, tmp_path)

# Assert: The extracted code should NOT include export (it's not exported)
# But it should include the class definition
assert "class InternalHelper" in context.target_code, (
f"Expected 'class InternalHelper' in extracted code, but got:\n{context.target_code}"
)
# Should not start with export
stripped_code = context.target_code.lstrip()
assert not stripped_code.startswith("export"), (
f"Non-exported class should not start with 'export', but got:\n{context.target_code}"
)
Loading
Loading