Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions flowrep/crawler.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,104 @@
import ast
import inspect
import types
from typing import Any
from collections.abc import Callable

from pyiron_snippets import versions

from flowrep.models.parsers import object_scope, parser_helpers

class CallCollector(ast.NodeVisitor):
def __init__(self):
self.calls = []
CallDependencies = dict[versions.VersionInfo, list[Callable]]

def visit_Call(self, node):
self.calls.append(node.func)
self.generic_visit(node)

def get_call_dependencies(
func: types.FunctionType,
version_scraping: versions.VersionScrapingMap | None = None,
_call_dependencies: CallDependencies | None = None,
_visited: set[str] | None = None,
) -> CallDependencies:
"""
Recursively collect all callable dependencies of *func* via AST introspection.

def _build_global_namespace(func) -> dict[str, object]:
return dict(func.__globals__)
Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo`
and maps to the list of concrete callables sharing that identity. The search
is depth-first: for every resolved callee that is a
:class:`~types.FunctionType` (i.e. has inspectable source), the function
recurses into the callee's own scope.

Args:
func: The function whose call-graph to analyse.
version_scraping (VersionScrapingMap | None): Since some modules may store
their version in other ways, this provides an optional map between module
names and callables to leverage for extracting that module's version.
_call_dependencies: Accumulator for recursive calls — do not pass manually.
_visited: Fully-qualified names already traversed — do not pass manually.

def _resolve_ast_node(node: ast.AST, namespace: dict[str, object]) -> Any:
Returns:
A mapping from :class:`VersionInfo` to the callables found under that
identity across the entire (sub-)tree.
"""
Resolve an AST node to its corresponding object in the given namespace.
call_dependencies: CallDependencies = _call_dependencies or {}
visited: set[str] = _visited or set()

func_fqn = versions.VersionInfo.of(func).fully_qualified_name
if func_fqn in visited:
return call_dependencies
visited.add(func_fqn)

scope = object_scope.get_scope(func)
tree = parser_helpers.get_ast_function_node(func)
collector = CallCollector()
collector.visit(tree)

for call in collector.calls:
try:
caller = object_scope.resolve_symbol_to_object(call, scope)
except (ValueError, TypeError):
continue

if not callable(caller):
continue

info = versions.VersionInfo.of(caller, version_scraping=version_scraping)
call_dependencies.setdefault(info, []).append(caller)

# Depth-first search on dependencies — only possible when we have source
if isinstance(caller, types.FunctionType):
get_call_dependencies(caller, version_scraping, call_dependencies, visited)

return call_dependencies


def split_by_version_availability(
call_dependencies: CallDependencies,
) -> tuple[CallDependencies, CallDependencies]:
"""
Partition *call_dependencies* by whether a version string is available.

Args:
node (ast.AST): The AST node to resolve.
namespace (dict[str, object]): The namespace to use for resolution.
call_dependencies: The dependency map to partition.

Returns:
Any: The resolved object, or None if it cannot be resolved.
A ``(has_version, no_version)`` tuple of :data:`CallDependencies` dicts.
"""
if isinstance(node, ast.Name):
return namespace.get(node.id)
has_version: CallDependencies = {}
no_version: CallDependencies = {}
for info, dependents in call_dependencies.items():
if info.version is None:
no_version[info] = dependents
else:
has_version[info] = dependents

if isinstance(node, ast.Attribute):
base = _resolve_ast_node(node.value, namespace)
if base is None:
return None
return getattr(base, node.attr, None)
return has_version, no_version

return None

class CallCollector(ast.NodeVisitor):
def __init__(self):
self.calls: list[ast.expr] = []

def visit_Call(self, node: ast.Call) -> None:
self.calls.append(node.func)
self.generic_visit(node)


def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType]:
Expand All @@ -58,11 +117,11 @@ def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType
collector = CallCollector()
collector.visit(tree)

namespace = _build_global_namespace(func)
namespace = object_scope.get_scope(func)
resolved = set()

for call_node in collector.calls:
obj = _resolve_ast_node(call_node, namespace)
obj = object_scope.resolve_symbol_to_object(call_node, namespace)
if callable(obj):
resolved.add(obj)

Expand Down
215 changes: 215 additions & 0 deletions tests/unit/test_crawler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import math
import unittest

from pyiron_snippets import versions

from flowrep import crawler


Expand All @@ -19,6 +21,77 @@ def more_op(a, b):
return c


# ---------------------------------------------------------------------------
# Helper functions defined at module level so they have inspectable source,
# a proper __module__, and a stable __qualname__.
# ---------------------------------------------------------------------------


def _leaf():
return 42


def _single_call():
return _leaf()


def _diamond_a():
return _leaf()


def _diamond_b():
return _leaf()


def _diamond_root():
_diamond_a()
_diamond_b()


def _mutual_b():
return _leaf()


def _mutual_a():
return _mutual_b()


# Mutual recursion to exercise cycle detection.
def _cycle_a():
return _cycle_b() # noqa: F821 — defined below


def _cycle_b():
return _cycle_a()


def _no_calls():
x = 1 + 2
return x


def _calls_len():
return len([1, 2, 3])


def _nested_call():
return _single_call()


def _multi_call():
a = _leaf()
b = _leaf()
return a + b


def _fqn(func) -> str:
return versions.VersionInfo.of(func).fully_qualified_name


def _fqns(deps: crawler.CallDependencies) -> set[str]:
return {info.fully_qualified_name for info in deps}


class TestCrawler(unittest.TestCase):
def test_analyze_function_dependencies(self):
loc, ext = crawler.analyze_function_dependencies(op)
Expand All @@ -39,5 +112,147 @@ def test_extract_called_functions(self):
self.assertEqual(called, {op})


class TestGetCallDependencies(unittest.TestCase):
"""Tests for :func:`crawler.get_call_dependencies`."""

# --- basic behaviour ---

def test_no_calls_returns_empty(self):
deps = crawler.get_call_dependencies(_no_calls)
self.assertEqual(deps, {})

def test_single_direct_call(self):
deps = crawler.get_call_dependencies(_single_call)
self.assertIn(_fqn(_leaf), _fqns(deps))

def test_transitive_dependencies(self):
deps = crawler.get_call_dependencies(_nested_call)
fqns = _fqns(deps)
# Should find both _single_call and _leaf
self.assertIn(_fqn(_single_call), fqns)
self.assertIn(_fqn(_leaf), fqns)

def test_diamond_dependency_no_duplicate_keys(self):
"""
_diamond_root -> _diamond_a -> _leaf AND _diamond_root -> _diamond_b -> _leaf.
_leaf's VersionInfo should appear exactly once as a key.
"""
deps = crawler.get_call_dependencies(_diamond_root)
matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)]
self.assertEqual(len(matching), 1)

# --- cycle safety ---

def test_cycle_does_not_recurse_infinitely(self):
# Should terminate without RecursionError
deps = crawler.get_call_dependencies(_cycle_a)
self.assertIn(_fqn(_cycle_b), _fqns(deps))

# --- builtins / non-FunctionType callables ---

def test_builtin_callable_included(self):
deps = crawler.get_call_dependencies(_calls_len)
self.assertIn(_fqn(len), _fqns(deps))

# --- accumulator semantics ---

def test_same_function_called_twice_appears_multiple_times_in_list(self):
deps = crawler.get_call_dependencies(_multi_call)
matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)]
self.assertEqual(len(matching), 1, "single key expected")
# The list value should have two entries (one per call-site)
self.assertEqual(len(deps[matching[0]]), 2)

def test_returns_dict_type(self):
deps = crawler.get_call_dependencies(_leaf)
self.assertIsInstance(deps, dict)


class TestSplitByVersionAvailability(unittest.TestCase):
"""Tests for :func:`crawler.split_by_version_availability`."""

@staticmethod
def _make_info(
module: str, qualname: str, version: str | None = None
) -> versions.VersionInfo:
return versions.VersionInfo(
module=module,
qualname=qualname,
version=version,
)

def test_empty_input(self):
has, no = crawler.split_by_version_availability({})
self.assertEqual(has, {})
self.assertEqual(no, {})

def test_all_versioned(self):
info_a = self._make_info("pkg", "a", "1.0")
info_b = self._make_info("pkg", "b", "2.0")
deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]}

has, no = crawler.split_by_version_availability(deps)
self.assertEqual(len(has), 2)
self.assertEqual(len(no), 0)

def test_all_unversioned(self):
info_a = self._make_info("local", "a")
info_b = self._make_info("local", "b")
deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]}

has, no = crawler.split_by_version_availability(deps)
self.assertEqual(len(has), 0)
self.assertEqual(len(no), 2)

def test_mixed(self):
versioned = self._make_info("pkg", "x", "3.1")
unversioned = self._make_info("local", "y")
deps: crawler.CallDependencies = {
versioned: [_leaf],
unversioned: [_single_call],
}

has, no = crawler.split_by_version_availability(deps)
self.assertIn(versioned, has)
self.assertIn(unversioned, no)
self.assertNotIn(versioned, no)
self.assertNotIn(unversioned, has)

def test_preserves_callable_lists(self):
info = self._make_info("pkg", "z", "1.0")
callables = [_leaf, _single_call, _no_calls]
deps: crawler.CallDependencies = {info: callables}

has, _ = crawler.split_by_version_availability(deps)
self.assertIs(has[info], callables)

def test_partition_is_exhaustive_and_disjoint(self):
"""Every key in the input appears in exactly one partition."""
infos = [
self._make_info("pkg", "a", "1.0"),
self._make_info("local", "b"),
self._make_info("pkg", "c", "0.1"),
self._make_info("local", "d"),
]
deps: crawler.CallDependencies = {info: [_leaf] for info in infos}

has, no = crawler.split_by_version_availability(deps)
self.assertEqual(set(has) | set(no), set(deps))
self.assertTrue(set(has).isdisjoint(set(no)))

def test_version_none_vs_empty_string(self):
"""Only ``None`` counts as unversioned; an empty string is still 'versioned'."""
none_version = self._make_info("local", "f", None)
empty_version = self._make_info("local", "g", "")
deps: crawler.CallDependencies = {
none_version: [_leaf],
empty_version: [_leaf],
}

has, no = crawler.split_by_version_availability(deps)
self.assertIn(none_version, no)
self.assertIn(empty_version, has)


if __name__ == "__main__":
unittest.main()
Loading