Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions codeflash/languages/java/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz

# Maximum token budget for imported type skeletons to avoid bloating testgen context
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
# and almost always exceed the token budget.
MAX_WILDCARD_TYPES_UNFILTERED = 50


def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
Expand Down Expand Up @@ -932,11 +936,29 @@ def get_java_imported_type_skeletons(
resolved_imports: list = []
for imp in imports:
if imp.is_wildcard:
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
expanded = resolver.expand_wildcard_import(imp.import_path)
# First try unfiltered expansion with a cap. If the package is small enough, take all types.
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
if priority_types:
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
logger.debug(
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
len(expanded),
)
else:
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
logger.debug(
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
)
elif expanded:
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
if expanded:
resolved_imports.extend(expanded)
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
continue

resolved = resolver.resolve_import(imp)
Expand Down
36 changes: 23 additions & 13 deletions codeflash/languages/java/import_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,20 @@ def _extract_class_name(self, import_path: str) -> str | None:
return last_part
return None

def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
def expand_wildcard_import(
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
) -> list[ResolvedImport]:
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.

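Resolves the package path to a directory and returns a ResolvedImport for each
.java file in that directory whose file name starts with an uppercase letter,
optionally restricted to the class names in filter_names and truncated once
max_types results have been collected.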

Args:
import_path: The package path (without the trailing .*).
max_types: Maximum number of types to return. 0 means no limit.
filter_names: If provided, only include types whose class name is in this set.

"""
# Convert package path to directory path
# e.g., "com.example.utils" -> "com/example/utils"
relative_dir = import_path.replace(".", "/")

resolved: list[ResolvedImport] = []
Expand All @@ -237,17 +243,21 @@ def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
if candidate_dir.is_dir():
for java_file in candidate_dir.glob("*.java"):
class_name = java_file.stem
# Only include files that look like class names (start with uppercase)
if class_name and class_name[0].isupper():
resolved.append(
ResolvedImport(
import_path=f"{import_path}.{class_name}",
file_path=java_file,
is_external=False,
is_wildcard=False,
class_name=class_name,
)
if not class_name or not class_name[0].isupper():
continue
if filter_names is not None and class_name not in filter_names:
continue
resolved.append(
ResolvedImport(
import_path=f"{import_path}.{class_name}",
file_path=java_file,
is_external=False,
is_wildcard=False,
class_name=class_name,
)
)
if max_types and len(resolved) >= max_types:
return resolved

return resolved

Expand Down
30 changes: 30 additions & 0 deletions tests/test_languages/test_java/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2530,6 +2530,36 @@ def test_wildcard_imports_are_expanded(self):
# Wildcard imports should now be expanded to individual classes found in the package directory
assert "MathHelper" in result

def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path):
    """A wildcard over more than MAX_WILDCARD_TYPES_UNFILTERED types keeps only referenced types.

    Builds a throwaway package containing MAX_WILDCARD_TYPES_UNFILTERED + 20 trivial
    classes, imports it with a wildcard, and checks that the generated context mentions
    the two types used by the target code and none of the others.
    """
    from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED

    # Create a minimal Maven project structure so the resolver finds source roots
    (tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
    source_root = tmp_path / "src" / "main" / "java"
    pkg_dir = source_root / "com" / "bigpkg"
    pkg_dir.mkdir(parents=True)

    # Deliberately overshoot the unfiltered cap so the filtering branch is exercised.
    type_names = [f"Type{i:03d}" for i in range(MAX_WILDCARD_TYPES_UNFILTERED + 20)]
    for index, type_name in enumerate(type_names):
        (pkg_dir / f"{type_name}.java").write_text(
            f"package com.bigpkg;\npublic class {type_name} {{ public int val() {{ return {index}; }} }}\n",
            encoding="utf-8",
        )

    analyzer = get_java_analyzer()
    # Target code references Type000 and Type001 only
    referenced = {"Type000", "Type001"}
    target_code = "Type000 a = new Type000(); Type001 b = a.transform();"
    source = "package com.example;\nimport com.bigpkg.*;\npublic class Foo { void bar() {} }"
    imports = analyzer.find_imports(source)

    result = get_java_imported_type_skeletons(
        imports, tmp_path, source_root, analyzer, target_code=target_code
    )

    # Every referenced type must be present.
    for type_name in sorted(referenced):
        assert type_name in result

    # No unreferenced type may leak in. Checking all of them (rather than one sample)
    # means the test cannot pass by luck if expansion falls back to an arbitrary capped subset.
    leaked = [name for name in type_names if name not in referenced and name in result]
    assert not leaked, f"unreferenced types leaked into context: {leaked}"

def test_import_to_nonexistent_class_in_file(self):
"""When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
analyzer = get_java_analyzer()
Expand Down
Loading