diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py
index 338ac5102..1b5ffa74a 100644
--- a/codeflash/languages/java/context.py
+++ b/codeflash/languages/java/context.py
@@ -863,6 +863,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz
 
 # Maximum token budget for imported type skeletons to avoid bloating testgen context
 IMPORTED_SKELETON_TOKEN_BUDGET = 4000
+# Maximum types to expand from a single wildcard import before filtering to referenced types only.
+# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
+# and almost always exceed the token budget.
+MAX_WILDCARD_TYPES_UNFILTERED = 50
 
 
 def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
@@ -932,11 +936,30 @@ def get_java_imported_type_skeletons(
     resolved_imports: list = []
     for imp in imports:
         if imp.is_wildcard:
-            # Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
-            expanded = resolver.expand_wildcard_import(imp.import_path)
+            # First try unfiltered expansion with a cap. If the package is small enough, take all types.
+            # If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
+            # Requesting one more than the cap is what lets us tell "exactly at the cap" from "over the cap".
+            expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
+            if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
+                if priority_types:
+                    expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
+                    logger.debug(
+                        "Wildcard %s.* exceeds %d types, filtered to %d referenced types",
+                        imp.import_path,
+                        MAX_WILDCARD_TYPES_UNFILTERED,
+                        len(expanded),
+                    )
+                else:
+                    expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
+                    logger.debug(
+                        "Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
+                        imp.import_path,
+                        MAX_WILDCARD_TYPES_UNFILTERED,
+                    )
+            elif expanded:
+                logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
             if expanded:
                 resolved_imports.extend(expanded)
-                logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
             continue
 
         resolved = resolver.resolve_import(imp)
diff --git a/codeflash/languages/java/import_resolver.py b/codeflash/languages/java/import_resolver.py
index a1f495555..1df20fb3c 100644
--- a/codeflash/languages/java/import_resolver.py
+++ b/codeflash/languages/java/import_resolver.py
@@ -220,14 +220,24 @@ def _extract_class_name(self, import_path: str) -> str | None:
                 return last_part
         return None
 
-    def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
+    def expand_wildcard_import(
+        self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
+    ) -> list[ResolvedImport]:
         """Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
 
         Resolves the package path to a directory and returns a ResolvedImport for each .java
         file found in that directory.
+
+        Args:
+            import_path: The package path (without the trailing .*).
+            max_types: Maximum number of types to return. 0 means no limit.
+            filter_names: If provided, only include types whose class name is in this set.
+
+        Returns:
+            The matching imports, ordered by file name within each candidate directory so that
+            a ``max_types`` cap always selects the same subset.
+
         """
-        # Convert package path to directory path
-        # e.g., "com.example.utils" -> "com/example/utils"
         relative_dir = import_path.replace(".", "/")
 
         resolved: list[ResolvedImport] = []
@@ -237,17 +247,22 @@ def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
             if candidate_dir.is_dir():
-                for java_file in candidate_dir.glob("*.java"):
+                # Sort so the max_types cap keeps a stable subset; glob order is filesystem-dependent.
+                for java_file in sorted(candidate_dir.glob("*.java")):
                     class_name = java_file.stem
-                    # Only include files that look like class names (start with uppercase)
-                    if class_name and class_name[0].isupper():
-                        resolved.append(
-                            ResolvedImport(
-                                import_path=f"{import_path}.{class_name}",
-                                file_path=java_file,
-                                is_external=False,
-                                is_wildcard=False,
-                                class_name=class_name,
-                            )
+                    if not class_name or not class_name[0].isupper():
+                        continue
+                    if filter_names is not None and class_name not in filter_names:
+                        continue
+                    resolved.append(
+                        ResolvedImport(
+                            import_path=f"{import_path}.{class_name}",
+                            file_path=java_file,
+                            is_external=False,
+                            is_wildcard=False,
+                            class_name=class_name,
                         )
+                    )
+                    if max_types and len(resolved) >= max_types:
+                        return resolved
 
         return resolved
 
diff --git a/tests/test_languages/test_java/test_context.py b/tests/test_languages/test_java/test_context.py
index 17dc1ca25..d9d771c01 100644
--- a/tests/test_languages/test_java/test_context.py
+++ b/tests/test_languages/test_java/test_context.py
@@ -2530,6 +2530,36 @@ def test_wildcard_imports_are_expanded(self):
         # Wildcard imports should now be expanded to individual classes found in the package directory
         assert "MathHelper" in result
 
+    def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path):
+        """When wildcard expands to >50 types, only types referenced in target code are included."""
+        from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED
+
+        # Create a minimal Maven project structure so the resolver finds source roots
+        (tmp_path / "pom.xml").write_text("", encoding="utf-8")
+        pkg_dir = tmp_path / "src" / "main" / "java" / "com" / "bigpkg"
+        pkg_dir.mkdir(parents=True)
+        for i in range(MAX_WILDCARD_TYPES_UNFILTERED + 20):
+            (pkg_dir / f"Type{i:03d}.java").write_text(
+                f"package com.bigpkg;\npublic class Type{i:03d} {{ public int val() {{ return {i}; }} }}\n",
+                encoding="utf-8",
+            )
+
+        analyzer = get_java_analyzer()
+        # Target code references Type000 and Type001 only
+        target_code = "Type000 a = new Type000(); Type001 b = a.transform();"
+        source = "package com.example;\nimport com.bigpkg.*;\npublic class Foo { void bar() {} }"
+        imports = analyzer.find_imports(source)
+
+        result = get_java_imported_type_skeletons(
+            imports, tmp_path, tmp_path / "src" / "main" / "java", analyzer, target_code=target_code
+        )
+
+        # Only referenced types should appear, not all 70
+        assert "Type000" in result
+        assert "Type001" in result
+        # Types not referenced in target code should be excluded
+        assert "Type050" not in result
+
     def test_import_to_nonexistent_class_in_file(self):
         """When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
         analyzer = get_java_analyzer()