diff --git a/flowrep/crawler.py b/flowrep/crawler.py index 05fcc3a6..6c1cbcfc 100644 --- a/flowrep/crawler.py +++ b/flowrep/crawler.py @@ -1,45 +1,104 @@ import ast import inspect import types -from typing import Any +from collections.abc import Callable from pyiron_snippets import versions +from flowrep.models.parsers import object_scope, parser_helpers -class CallCollector(ast.NodeVisitor): - def __init__(self): - self.calls = [] +CallDependencies = dict[versions.VersionInfo, list[Callable]] - def visit_Call(self, node): - self.calls.append(node.func) - self.generic_visit(node) +def get_call_dependencies( + func: types.FunctionType, + version_scraping: versions.VersionScrapingMap | None = None, + _call_dependencies: CallDependencies | None = None, + _visited: set[str] | None = None, +) -> CallDependencies: + """ + Recursively collect all callable dependencies of *func* via AST introspection. -def _build_global_namespace(func) -> dict[str, object]: - return dict(func.__globals__) + Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo` + and maps to the list of concrete callables sharing that identity. The search + is depth-first: for every resolved callee that is a + :class:`~types.FunctionType` (i.e. has inspectable source), the function + recurses into the callee's own scope. + Args: + func: The function whose call-graph to analyse. + version_scraping (VersionScrapingMap | None): Since some modules may store + their version in other ways, this provides an optional map between module + names and callables to leverage for extracting that module's version. + _call_dependencies: Accumulator for recursive calls — do not pass manually. + _visited: Fully-qualified names already traversed — do not pass manually. -def _resolve_ast_node(node: ast.AST, namespace: dict[str, object]) -> Any: + Returns: + A mapping from :class:`VersionInfo` to the callables found under that + identity across the entire (sub-)tree. """ - Resolve an AST node to its corresponding object in the given namespace. + call_dependencies: CallDependencies = _call_dependencies or {} + visited: set[str] = _visited or set() + + func_fqn = versions.VersionInfo.of(func).fully_qualified_name + if func_fqn in visited: + return call_dependencies + visited.add(func_fqn) + + scope = object_scope.get_scope(func) + tree = parser_helpers.get_ast_function_node(func) + collector = CallCollector() + collector.visit(tree) + + for call in collector.calls: + try: + caller = object_scope.resolve_symbol_to_object(call, scope) + except (ValueError, TypeError): + continue + + if not callable(caller): + continue + + info = versions.VersionInfo.of(caller, version_scraping=version_scraping) + call_dependencies.setdefault(info, []).append(caller) + + # Depth-first search on dependencies — only possible when we have source + if isinstance(caller, types.FunctionType): + get_call_dependencies(caller, version_scraping, call_dependencies, visited) + + return call_dependencies + + +def split_by_version_availability( + call_dependencies: CallDependencies, +) -> tuple[CallDependencies, CallDependencies]: + """ + Partition *call_dependencies* by whether a version string is available. Args: - node (ast.AST): The AST node to resolve. - namespace (dict[str, object]): The namespace to use for resolution. + call_dependencies: The dependency map to partition. Returns: - Any: The resolved object, or None if it cannot be resolved. + A ``(has_version, no_version)`` tuple of :data:`CallDependencies` dicts. """ - if isinstance(node, ast.Name): - return namespace.get(node.id) + has_version: CallDependencies = {} + no_version: CallDependencies = {} + for info, dependents in call_dependencies.items(): + if info.version is None: + no_version[info] = dependents + else: + has_version[info] = dependents - if isinstance(node, ast.Attribute): - base = _resolve_ast_node(node.value, namespace) - if base is None: - return None - return getattr(base, node.attr, None) + return has_version, no_version - return None + +class CallCollector(ast.NodeVisitor): + def __init__(self): + self.calls: list[ast.expr] = [] + + def visit_Call(self, node: ast.Call) -> None: + self.calls.append(node.func) + self.generic_visit(node) def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType]: @@ -58,11 +117,11 @@ def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType collector = CallCollector() collector.visit(tree) - namespace = _build_global_namespace(func) + namespace = object_scope.get_scope(func) resolved = set() for call_node in collector.calls: - obj = _resolve_ast_node(call_node, namespace) + obj = object_scope.resolve_symbol_to_object(call_node, namespace) if callable(obj): resolved.add(obj) diff --git a/tests/unit/test_crawler.py b/tests/unit/test_crawler.py index 909d280e..4c72369b 100644 --- a/tests/unit/test_crawler.py +++ b/tests/unit/test_crawler.py @@ -1,6 +1,8 @@ import math import unittest +from pyiron_snippets import versions + from flowrep import crawler @@ -19,6 +21,77 @@ def more_op(a, b): return c +# --------------------------------------------------------------------------- +# Helper functions defined at module level so they have inspectable source, +# a proper __module__, and a stable __qualname__. +# --------------------------------------------------------------------------- + + +def _leaf(): + return 42 + + +def _single_call(): + return _leaf() + + +def _diamond_a(): + return _leaf() + + +def _diamond_b(): + return _leaf() + + +def _diamond_root(): + _diamond_a() + _diamond_b() + + +def _mutual_b(): + return _leaf() + + +def _mutual_a(): + return _mutual_b() + + +# Mutual recursion to exercise cycle detection. +def _cycle_a(): + return _cycle_b() # noqa: F821 — defined below + + +def _cycle_b(): + return _cycle_a() + + +def _no_calls(): + x = 1 + 2 + return x + + +def _calls_len(): + return len([1, 2, 3]) + + +def _nested_call(): + return _single_call() + + +def _multi_call(): + a = _leaf() + b = _leaf() + return a + b + + +def _fqn(func) -> str: + return versions.VersionInfo.of(func).fully_qualified_name + + +def _fqns(deps: crawler.CallDependencies) -> set[str]: + return {info.fully_qualified_name for info in deps} + + class TestCrawler(unittest.TestCase): def test_analyze_function_dependencies(self): loc, ext = crawler.analyze_function_dependencies(op) @@ -39,5 +112,147 @@ def test_extract_called_functions(self): self.assertEqual(called, {op}) +class TestGetCallDependencies(unittest.TestCase): + """Tests for :func:`crawler.get_call_dependencies`.""" + + # --- basic behaviour --- + + def test_no_calls_returns_empty(self): + deps = crawler.get_call_dependencies(_no_calls) + self.assertEqual(deps, {}) + + def test_single_direct_call(self): + deps = crawler.get_call_dependencies(_single_call) + self.assertIn(_fqn(_leaf), _fqns(deps)) + + def test_transitive_dependencies(self): + deps = crawler.get_call_dependencies(_nested_call) + fqns = _fqns(deps) + # Should find both _single_call and _leaf + self.assertIn(_fqn(_single_call), fqns) + self.assertIn(_fqn(_leaf), fqns) + + def test_diamond_dependency_no_duplicate_keys(self): + """ + _diamond_root -> _diamond_a -> _leaf AND _diamond_root -> _diamond_b -> _leaf. + _leaf's VersionInfo should appear exactly once as a key. + """ + deps = crawler.get_call_dependencies(_diamond_root) + matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)] + self.assertEqual(len(matching), 1) + + # --- cycle safety --- + + def test_cycle_does_not_recurse_infinitely(self): + # Should terminate without RecursionError + deps = crawler.get_call_dependencies(_cycle_a) + self.assertIn(_fqn(_cycle_b), _fqns(deps)) + + # --- builtins / non-FunctionType callables --- + + def test_builtin_callable_included(self): + deps = crawler.get_call_dependencies(_calls_len) + self.assertIn(_fqn(len), _fqns(deps)) + + # --- accumulator semantics --- + + def test_same_function_called_twice_appears_multiple_times_in_list(self): + deps = crawler.get_call_dependencies(_multi_call) + matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)] + self.assertEqual(len(matching), 1, "single key expected") + # The list value should have two entries (one per call-site) + self.assertEqual(len(deps[matching[0]]), 2) + + def test_returns_dict_type(self): + deps = crawler.get_call_dependencies(_leaf) + self.assertIsInstance(deps, dict) + + +class TestSplitByVersionAvailability(unittest.TestCase): + """Tests for :func:`crawler.split_by_version_availability`.""" + + @staticmethod + def _make_info( + module: str, qualname: str, version: str | None = None + ) -> versions.VersionInfo: + return versions.VersionInfo( + module=module, + qualname=qualname, + version=version, + ) + + def test_empty_input(self): + has, no = crawler.split_by_version_availability({}) + self.assertEqual(has, {}) + self.assertEqual(no, {}) + + def test_all_versioned(self): + info_a = self._make_info("pkg", "a", "1.0") + info_b = self._make_info("pkg", "b", "2.0") + deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]} + + has, no = crawler.split_by_version_availability(deps) + self.assertEqual(len(has), 2) + self.assertEqual(len(no), 0) + + def test_all_unversioned(self): + info_a = self._make_info("local", "a") + info_b = self._make_info("local", "b") + deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]} + + has, no = crawler.split_by_version_availability(deps) + self.assertEqual(len(has), 0) + self.assertEqual(len(no), 2) + + def test_mixed(self): + versioned = self._make_info("pkg", "x", "3.1") + unversioned = self._make_info("local", "y") + deps: crawler.CallDependencies = { + versioned: [_leaf], + unversioned: [_single_call], + } + + has, no = crawler.split_by_version_availability(deps) + self.assertIn(versioned, has) + self.assertIn(unversioned, no) + self.assertNotIn(versioned, no) + self.assertNotIn(unversioned, has) + + def test_preserves_callable_lists(self): + info = self._make_info("pkg", "z", "1.0") + callables = [_leaf, _single_call, _no_calls] + deps: crawler.CallDependencies = {info: callables} + + has, _ = crawler.split_by_version_availability(deps) + self.assertIs(has[info], callables) + + def test_partition_is_exhaustive_and_disjoint(self): + """Every key in the input appears in exactly one partition.""" + infos = [ + self._make_info("pkg", "a", "1.0"), + self._make_info("local", "b"), + self._make_info("pkg", "c", "0.1"), + self._make_info("local", "d"), + ] + deps: crawler.CallDependencies = {info: [_leaf] for info in infos} + + has, no = crawler.split_by_version_availability(deps) + self.assertEqual(set(has) | set(no), set(deps)) + self.assertTrue(set(has).isdisjoint(set(no))) + + def test_version_none_vs_empty_string(self): + """Only ``None`` counts as unversioned; an empty string is still 'versioned'.""" + none_version = self._make_info("local", "f", None) + empty_version = self._make_info("local", "g", "") + deps: crawler.CallDependencies = { + none_version: [_leaf], + empty_version: [_leaf], + } + + has, no = crawler.split_by_version_availability(deps) + self.assertIn(none_version, no) + self.assertIn(empty_version, has) + + if __name__ == "__main__": unittest.main()