diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index db96c4df1..b172f3bb9 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -350,7 +350,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, # Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields class_info = self._find_class_definition(source, class_name, analyzer, function.function_name) if class_info: - class_jsdoc, class_indent, constructor_code, fields_code = class_info + class_jsdoc, class_indent, constructor_code, fields_code, is_exported = class_info # Build the class body with fields, constructor, target method, and same-class helpers class_body_parts = [] if fields_code: @@ -364,12 +364,14 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, class_body = "\n".join(class_body_parts) # Wrap the method in a class definition with context + # Include 'export' keyword if the class is exported + export_keyword = "export " if is_exported else "" if class_jsdoc: + target_code = f"{class_jsdoc}\n{class_indent}{export_keyword}class {class_name} {{\n{class_body}{class_indent}}}\n" + else: target_code = ( - f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n" + f"{class_indent}{export_keyword}class {class_name} {{\n{class_body}{class_indent}}}\n" ) - else: - target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n" else: # Fallback: wrap with no indentation, including same-class helpers helper_code = "\n".join(h[1] for h in same_class_helpers) @@ -432,7 +434,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, def _find_class_definition( self, source: str, class_name: str, analyzer: TreeSitterAnalyzer, target_method_name: str | None = None - ) -> tuple[str, str, str, str] | None: + ) -> tuple[str, str, str, str, bool] | None: """Find a class definition and extract its JSDoc, indentation, constructor, and fields. Args: @@ -442,7 +444,7 @@ def _find_class_definition( target_method_name: Name of the target method (to exclude from extracted context). Returns: - Tuple of (jsdoc_comment, indentation, constructor_code, fields_code) or None if not found. + Tuple of (jsdoc_comment, indentation, constructor_code, fields_code, is_exported) or None if not found. Constructor and fields are included to provide context for method optimization. """ @@ -467,6 +469,11 @@ def find_class_node(node): if not class_node: return None + # Check if the class is exported by examining its parent node + is_exported = False + if class_node.parent and class_node.parent.type == "export_statement": + is_exported = True + # Get indentation from the class line lines = source.splitlines(keepends=True) class_line_idx = class_node.start_point[0] @@ -495,7 +502,7 @@ def find_class_node(node): body_node, source_bytes, lines, target_method_name ) - return (jsdoc, indentation, constructor_code, fields_code) + return (jsdoc, indentation, constructor_code, fields_code, is_exported) def _extract_class_context( self, body_node: Any, source_bytes: bytes, lines: list[str], target_method_name: str | None @@ -2164,8 +2171,9 @@ def find_test_root(self, project_root: Path) -> Path | None: def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str: """Get the module path for importing a JavaScript source file from tests. - For JavaScript, this returns a relative path from the tests directory to the source file - (e.g., '../fibonacci' for source at /project/fibonacci.js and tests at /project/tests/). + For JavaScript/TypeScript, this returns a relative path from the tests directory to + the source file. For ESM projects or TypeScript, the path includes a .js extension + (TypeScript convention). For CommonJS, no extension is added. Args: source_file: Path to the source file. @@ -2179,13 +2187,15 @@ def get_module_path(self, source_file: Path, project_root: Path, tests_root: Pat import os from codeflash.cli_cmds.console import logger + from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system if tests_root is None: tests_root = self.find_test_root(project_root) or project_root try: # Resolve both paths to absolute to ensure consistent relative path calculation - source_file_abs = source_file.resolve().with_suffix("") + # Note: Don't remove extension yet - we'll decide based on module system + source_file_abs = source_file.resolve() tests_root_abs = tests_root.resolve() # Find the project root using language support @@ -2205,16 +2215,40 @@ def get_module_path(self, source_file: Path, project_root: Path, tests_root: Pat if not tests_root_abs.exists(): tests_root_abs = project_root_from_lang + # Detect module system to determine if we need to add .js extension + module_system = detect_module_system(project_root, source_file) + + # Remove source file extension first + source_without_ext = source_file_abs.with_suffix("") + # Use os.path.relpath to compute relative path from tests_root to source file - rel_path = os.path.relpath(str(source_file_abs), str(tests_root_abs)) - logger.debug( - f"!lsp|Module path: source={source_file_abs}, tests_root={tests_root_abs}, rel_path={rel_path}" - ) + rel_path = os.path.relpath(str(source_without_ext), str(tests_root_abs)) + + # For ESM, add .js extension (TypeScript convention) + # TypeScript requires imports to reference the OUTPUT file extension (.js), + # even when the source file is .ts. This is required for Node.js ESM resolution. + if module_system == ModuleSystem.ES_MODULE: + rel_path = rel_path + ".js" + logger.debug( + f"!lsp|Module path (ESM): source={source_file_abs}, tests_root={tests_root_abs}, " + f"rel_path={rel_path} (added .js for ESM)" + ) + else: + logger.debug( + f"!lsp|Module path (CommonJS): source={source_file_abs}, tests_root={tests_root_abs}, " + f"rel_path={rel_path}" + ) + return rel_path except ValueError: # Fallback if paths are on different drives (Windows) rel_path = source_file.relative_to(project_root) - return "../" + rel_path.with_suffix("").as_posix() + # For fallback, also check module system + module_system = detect_module_system(project_root, source_file) + path_without_ext = "../" + rel_path.with_suffix("").as_posix() + if module_system == ModuleSystem.ES_MODULE: + return path_without_ext + ".js" + return path_without_ext def verify_requirements(self, project_root: Path, test_framework: str = "jest") -> tuple[bool, list[str]]: """Verify that all JavaScript requirements are met. diff --git a/tests/test_export_keyword_preservation.py b/tests/test_export_keyword_preservation.py new file mode 100644 index 000000000..ca1344c9e --- /dev/null +++ b/tests/test_export_keyword_preservation.py @@ -0,0 +1,166 @@ +"""Test that export keywords are preserved during code extraction.""" + +import pytest +from pathlib import Path +from codeflash.languages.javascript.support import JavaScriptSupport +from codeflash.models.function_types import FunctionToOptimize + + +class TestExportKeywordPreservation: + """Test export keyword is preserved when extracting code context.""" + + def test_export_class_includes_export_keyword(self, tmp_path: Path): + """Test that exported classes include the export keyword in extracted code.""" + # Arrange: Create a test file with an exported class + test_file = tmp_path / "test.ts" + test_code = """export class WsContextCreator { + public getMetadata(instance: any, methodName: string): any { + return { test: true }; + } +}""" + test_file.write_text(test_code) + + # Create a FunctionToOptimize for the class method + from codeflash.models.function_types import FunctionParent + function = FunctionToOptimize( + function_name="getMetadata", + file_path=test_file, + parents=[FunctionParent(type="ClassDef", name="WsContextCreator")], + starting_line=2, # The method starts at line 2 + ending_line=4, # The method ends at line 4 + starting_col=2, + ending_col=3, + is_async=False, + is_method=True, + language="typescript", + doc_start_line=None, + ) + + # Act: Extract code context + support = JavaScriptSupport() + context = support.extract_code_context(function, tmp_path, tmp_path) + + # Assert: The extracted code should include the export keyword + assert "export" in context.target_code, ( + f"Export keyword missing from extracted code. " + f"Expected code to start with 'export class', but got:\n{context.target_code}" + ) + assert "export class WsContextCreator" in context.target_code, ( + f"Expected 'export class WsContextCreator' in extracted code, but got:\n{context.target_code}" + ) + + def test_export_function_includes_export_keyword(self, tmp_path: Path): + """Test that exported functions include the export keyword in extracted code.""" + # Arrange: Create a test file with an exported function + test_file = tmp_path / "test.ts" + test_code = """export function helperFunction(a: number, b: number): number { + return a + b; +}""" + test_file.write_text(test_code) + + # Create a FunctionToOptimize for the function + function = FunctionToOptimize( + function_name="helperFunction", + file_path=test_file, + parents=[], + starting_line=1, + ending_line=3, + starting_col=0, + ending_col=1, + is_async=False, + is_method=False, + language="typescript", + doc_start_line=None, + ) + + # Act: Extract code context + support = JavaScriptSupport() + context = support.extract_code_context(function, tmp_path, tmp_path) + + # Assert: The extracted code should include the export keyword + assert "export" in context.target_code, ( + f"Export keyword missing from extracted code for function. " + f"Expected code to start with 'export function', but got:\n{context.target_code}" + ) + assert "export function helperFunction" in context.target_code, ( + f"Expected 'export function helperFunction' in extracted code, but got:\n{context.target_code}" + ) + + def test_export_const_arrow_function_includes_export(self, tmp_path: Path): + """Test that exported const arrow functions include the export keyword.""" + # Arrange: Create a test file with an exported const arrow function + test_file = tmp_path / "test.ts" + test_code = """export const multiply = (a: number, b: number): number => { + return a * b; +};""" + test_file.write_text(test_code) + + # Create a FunctionToOptimize for the arrow function + function = FunctionToOptimize( + function_name="multiply", + file_path=test_file, + parents=[], + starting_line=1, + ending_line=3, + starting_col=0, + ending_col=2, + is_async=False, + is_method=False, + language="typescript", + doc_start_line=None, + ) + + # Act: Extract code context + support = JavaScriptSupport() + context = support.extract_code_context(function, tmp_path, tmp_path) + + # Assert: The extracted code should include the export keyword + assert "export" in context.target_code, ( + f"Export keyword missing from exported const arrow function. " + f"Expected code to start with 'export const', but got:\n{context.target_code}" + ) + assert "export const multiply" in context.target_code, ( + f"Expected 'export const multiply' in extracted code, but got:\n{context.target_code}" + ) + + def test_non_exported_class_unchanged(self, tmp_path: Path): + """Test that non-exported classes work correctly (baseline test).""" + # Arrange: Create a test file with a NON-exported class + test_file = tmp_path / "test.ts" + test_code = """class InternalHelper { + process(): void { + console.log('internal'); + } +}""" + test_file.write_text(test_code) + + # Create a FunctionToOptimize for the method + from codeflash.models.function_types import FunctionParent + function = FunctionToOptimize( + function_name="process", + file_path=test_file, + parents=[FunctionParent(type="ClassDef", name="InternalHelper")], + starting_line=2, + ending_line=4, + starting_col=2, + ending_col=3, + is_async=False, + is_method=True, + language="typescript", + doc_start_line=None, + ) + + # Act: Extract code context + support = JavaScriptSupport() + context = support.extract_code_context(function, tmp_path, tmp_path) + + # Assert: The extracted code should NOT include export (it's not exported) + # But it should include the class definition + assert "class InternalHelper" in context.target_code, ( + f"Expected 'class InternalHelper' in extracted code, but got:\n{context.target_code}" + ) + # Should not start with export + stripped_code = context.target_code.lstrip() + assert not stripped_code.startswith("export"), ( + f"Non-exported class should not start with 'export', but got:\n{context.target_code}" + ) diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index be440d7ae..4b27b98bd 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -756,8 +756,8 @@ def test_extract_class_method_wraps_in_class(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check for exact extraction output - # Note: export keyword is not included in extracted class wrapper - expected_code = """class Calculator { + # Export keyword IS included for exported classes + expected_code = """export class Calculator { add(a, b) { return a + b; } @@ -793,9 +793,9 @@ def test_extract_class_method_with_jsdoc(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check - includes class JSDoc, class definition, method JSDoc, and method - # Note: export keyword is not included in extracted class wrapper + # Export keyword IS included for exported classes # Note: Class-level JSDoc is not included when extracting a method - expected_code = """class Calculator { + expected_code = """export class Calculator { /** * Adds two numbers. * @param {number} a - First number @@ -831,8 +831,8 @@ def test_extract_class_method_syntax_valid(self, js_support): context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class FibonacciCalculator { + # Export keyword IS included for exported classes + expected_code = """export class FibonacciCalculator { fibonacci(n) { if (n <= 1) { return n; @@ -871,8 +871,8 @@ def test_extract_nested_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class Outer { + # Export keyword IS included for exported classes + expected_code = """export class Outer { add(a, b) { return a + b; } @@ -900,8 +900,8 @@ def test_extract_async_class_method(self, js_support): context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class ApiClient { + # Export keyword IS included for exported classes + expected_code = """export class ApiClient { async fetchData(url) { const response = await fetch(url); return response.json(); @@ -934,8 +934,8 @@ def test_extract_static_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class MathUtils { + # Export keyword IS included for exported classes + expected_code = """export class MathUtils { static add(a, b) { return a + b; } @@ -962,8 +962,8 @@ def test_extract_class_method_without_class_jsdoc(self, js_support): context = js_support.extract_code_context(method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class SimpleClass { + # Export keyword IS included for exported classes + expected_code = """export class SimpleClass { simpleMethod() { return "hello"; } @@ -1208,8 +1208,8 @@ def test_class_extending_another(self, js_support): context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check - # Note: export keyword is not included in extracted class wrapper - expected_code = """class Dog { + # Export keyword IS included for exported classes + expected_code = """export class Dog { fetch() { return 'ball'; } @@ -1334,8 +1334,9 @@ def test_extract_context_then_replace_method(self, js_support): context = js_support.extract_code_context(increment_func, file_path.parent, file_path.parent) # Verify extraction with exact string equality + # Export keyword IS included for exported classes expected_extraction = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1351,10 +1352,10 @@ class Counter { f"Expected:\n{expected_extraction}\n\nGot:\n{context.target_code}" ) - # Step 2: AI returns optimized code as FULL CLASS (not just method) + # Step 2: AI returns optimized code as FULL CLASS (with export) # This simulates what the AI would return - the full context with optimized method optimized_code_from_ai = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1431,8 +1432,9 @@ def test_typescript_extract_context_then_replace_method(self): context = ts_support.extract_code_context(get_name_func, file_path.parent, file_path.parent) # Verify extraction with exact string equality + # Export keyword IS included for exported classes expected_extraction = """\ -class User { +export class User { private name: string; private age: number; @@ -1451,9 +1453,9 @@ class User { f"Expected:\n{expected_extraction}\n\nGot:\n{context.target_code}" ) - # Step 2: AI returns optimized code as FULL CLASS + # Step 2: AI returns optimized code as FULL CLASS (with export) optimized_code_from_ai = """\ -class User { +export class User { private name: string; private age: number; @@ -1531,8 +1533,9 @@ def test_extract_replace_preserves_other_methods(self, js_support): context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent) # Verify extraction with exact string equality + # Export keyword IS included for exported classes expected_extraction = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1547,9 +1550,9 @@ class Calculator { f"Expected:\n{expected_extraction}\n\nGot:\n{context.target_code}" ) - # AI returns optimized code as FULL CLASS + # AI returns optimized code as FULL CLASS (with export) optimized_code_from_ai = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1615,8 +1618,9 @@ def test_extract_static_method_then_replace(self, js_support): context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent) # Verify extraction with exact string equality + # Export keyword IS included for exported classes expected_extraction = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1631,9 +1635,9 @@ class MathUtils { f"Expected:\n{expected_extraction}\n\nGot:\n{context.target_code}" ) - # AI returns optimized code as FULL CLASS + # AI returns optimized code as FULL CLASS (with export) optimized_code_from_ai = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1869,3 +1873,75 @@ def test_mixed_top_level_and_indented(self): test('works', () => {}); });""" assert fix_imports_inside_blocks(source) == expected + + +class TestGetModulePath: + """Tests for get_module_path method to ensure proper module resolution.""" + + def test_get_module_path_typescript_esm_adds_js_extension(self, js_support): + """Test that TypeScript files in ESM projects get .js extension in import paths. + + This is the TypeScript convention: imports reference the OUTPUT file extension (.js) + even when the source file is .ts. This is required for Node.js ESM resolution. + + Regression test for: ERR_MODULE_NOT_FOUND when importing TypeScript modules + Trace ID: 08d0e99e-10e6-4ad2-981d-b907e3c068ea + """ + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + # Create a TypeScript source file + source_dir = project_root / "packages" / "microservices" / "server" + source_dir.mkdir(parents=True) + source_file = source_dir / "server-factory.ts" + source_file.write_text("export class ServerFactory {}") + + # Create tests directory + tests_dir = project_root / "packages" / "microservices" / "test" / "codeflash-generated" + tests_dir.mkdir(parents=True) + + # Create package.json with type: module (ESM) + package_json = project_root / "package.json" + package_json.write_text('{"type": "module"}') + + # Get module path + module_path = js_support.get_module_path(source_file, project_root, tests_dir) + + # For ESM/TypeScript, the import path should end with .js + # This is TypeScript's convention: imports use .js extension even for .ts files + assert module_path.endswith(".js"), ( + f"Expected module path to end with .js for ESM/TypeScript, got: {module_path}. " + "Node.js ESM requires explicit file extensions in import statements." + ) + + # The path should be relative (start with ../ or ./) + assert module_path.startswith(("../", "./")), ( + f"Expected relative import path, got: {module_path}" + ) + + def test_get_module_path_commonjs_no_extension(self, js_support): + """Test that CommonJS projects get module paths without extensions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + # Create a JavaScript source file + source_dir = project_root / "src" + source_dir.mkdir(parents=True) + source_file = source_dir / "utils.js" + source_file.write_text("module.exports = {}") + + # Create tests directory + tests_dir = project_root / "test" + tests_dir.mkdir(parents=True) + + # Create package.json WITHOUT type field (defaults to CommonJS) + package_json = project_root / "package.json" + package_json.write_text('{"name": "test-project"}') + + # Get module path + module_path = js_support.get_module_path(source_file, project_root, tests_dir) + + # For CommonJS, no extension is fine + assert not module_path.endswith((".js", ".ts", ".tsx")), ( + f"Expected module path without extension for CommonJS, got: {module_path}" + )