diff --git a/src/flowrep/parsers/dependency_parser.py b/src/flowrep/parsers/dependency_parser.py index c3865923..519def8c 100644 --- a/src/flowrep/parsers/dependency_parser.py +++ b/src/flowrep/parsers/dependency_parser.py @@ -1,85 +1,16 @@ +from __future__ import annotations + import ast -import types +import builtins +import inspect +import textwrap from collections.abc import Callable from pyiron_snippets import versions -from flowrep.parsers import import_parser, object_scope, parser_helpers - -CallDependencies = dict[versions.VersionInfo, Callable] - - -def get_call_dependencies( - func: types.FunctionType, - version_scraping: versions.VersionScrapingMap | None = None, - _call_dependencies: CallDependencies | None = None, - _visited: set[str] | None = None, -) -> CallDependencies: - """ - Recursively collect all callable dependencies of *func* via AST introspection. - - Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo` - and maps to the callables instance with that identity. The search is depth-first: - for every resolved callee that is a :class:`~types.FunctionType` (i.e. has - inspectable source), the function recurses into the callee's own scope. +from flowrep.parsers import object_scope - Args: - func: The function whose call-graph to analyse. - version_scraping (VersionScrapingMap | None): Since some modules may store - their version in other ways, this provides an optional map between module - names and callables to leverage for extracting that module's version. - _call_dependencies: Accumulator for recursive calls — do not pass manually. - _visited: Fully-qualified names already traversed — do not pass manually. - - Returns: - A mapping from :class:`VersionInfo` to the callables found under that - identity across the entire (sub-)tree. - """ - call_dependencies: CallDependencies = _call_dependencies or {} - visited: set[str] = _visited or set() - - func_fqn = versions.VersionInfo.of(func).fully_qualified_name - if func_fqn in visited: - return call_dependencies - visited.add(func_fqn) - - tree = parser_helpers.get_ast_function_node(func) - collector = CallCollector() - collector.visit(tree) - local_modules = import_parser.build_scope(collector.imports, collector.import_froms) - scope = object_scope.get_scope(func) - for name, obj in local_modules.items(): - scope.register(name=name, obj=obj) - - for call in collector.calls: - try: - caller = object_scope.resolve_symbol_to_object(call, scope) - except (ValueError, TypeError): - continue - - if not callable(caller): # pragma: no cover - # Under remotely normal circumstances, this should be unreachable - raise TypeError( - f"Caller {caller} is not callable, yet was generated from the list of " - f"ast.Call calls, in particular {call}. We're expecting these to " - f"actually connect to callables. Please raise a GitHub issue if you " - f"think this is not a mistake." - ) - - info = versions.VersionInfo.of(caller, version_scraping=version_scraping) - # In principle, we open ourselves to overwriting an existing dependency here, - # but it would need to somehow have exactly the same version info (including - # qualname) yet be a different object. - # This ought not happen by accident, and in case it somehow does happen on - # purpose (it probably shouldn't), we just silently keep the more recent one. - - call_dependencies[info] = caller - - # Depth-first search on dependencies — only possible when we have source - if isinstance(caller, types.FunctionType): - get_call_dependencies(caller, version_scraping, call_dependencies, visited) - - return call_dependencies +CallDependencies = dict[versions.VersionInfo, object] def split_by_version_availability( @@ -105,20 +36,118 @@ def split_by_version_availability( return has_version, no_version -class CallCollector(ast.NodeVisitor): - def __init__(self): - self.calls: list[ast.expr] = [] - self.imports: list[ast.Import] = [] - self.import_froms: list[ast.ImportFrom] = [] +class UndefinedVariableVisitor(ast.NodeVisitor): + """AST visitor that collects used and locally-defined variable names. - def visit_Call(self, node: ast.Call) -> None: - self.calls.append(node.func) - self.generic_visit(node) + Local (nested) function definitions inside the analysed function body are + **not** supported: encountering one raises :exc:`NotImplementedError` so + that callers fail fast with a clear message instead of silently producing + wrong dependency results. + + Class definitions at any nesting level are tracked in :attr:`defined_vars` + so that class names used later in the same scope are not reported as + undefined symbols. + """ - def visit_Import(self, node: ast.Import) -> None: - self.imports.append(node) + def __init__(self): + self.used_vars: set[str] = set() + self.defined_vars: set[str] = set() + self._nesting_depth: int = 0 + + def visit_Name(self, node: ast.Name) -> None: + if isinstance(node.ctx, ast.Load): + self.used_vars.add(node.id) + elif isinstance(node.ctx, ast.Store): + self.defined_vars.add(node.id) + + def _visit_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: + if self._nesting_depth > 0: + keyword = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def" + raise NotImplementedError( + f"Local function definitions are not supported: " + f"'{keyword} {node.name}' inside a function body cannot be " + "analysed for dependencies." + ) + # Register the function name and all of its parameters so that + # recursive calls and uses of any argument inside the body are not + # reported as undefined external symbols. + self.defined_vars.add(node.name) + all_args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs + for arg in all_args: + self.defined_vars.add(arg.arg) + if node.args.vararg: + self.defined_vars.add(node.args.vararg.arg) + if node.args.kwarg: + self.defined_vars.add(node.args.kwarg.arg) + self._nesting_depth += 1 self.generic_visit(node) + self._nesting_depth -= 1 + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self._visit_function_def(node) - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - self.import_froms.append(node) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self._visit_function_def(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self.defined_vars.add(node.name) self.generic_visit(node) + + +def find_undefined_variables( + func_or_var: Callable | object, +) -> set[str]: + """ + Find variables that are used but not defined in the source of *func_or_var*. + + If the source code for *func_or_var* cannot be retrieved or parsed (e.g., + for certain built-in objects or when no source is available), this + function returns an empty set instead of raising an exception. + """ + try: + # Prefer actual source code over string representations for both + # callables and other inspectable objects (e.g. classes, modules). + raw_source = inspect.getsource(func_or_var) + except (OSError, TypeError): + # No reliable source available; treat as having no undefined variables. + return set() + + source = textwrap.dedent(raw_source) + + try: + tree = ast.parse(source) + except SyntaxError: + # Source could not be parsed as Python code; fail gracefully. + return set() + + visitor = UndefinedVariableVisitor() + visitor.visit(tree) + undefined_vars = visitor.used_vars - visitor.defined_vars + return undefined_vars.difference(set(dir(builtins))) + + +def get_call_dependencies( + func_or_var: Callable | object, + version_scraping: versions.VersionScrapingMap | None = None, + _call_dependencies: CallDependencies | None = None, + _visited: set[str] | None = None, +) -> CallDependencies: + + call_dependencies: CallDependencies = _call_dependencies or {} + visited: set[str] = _visited or set() + + func_fqn = versions.VersionInfo.of(func_or_var).fully_qualified_name + if func_fqn in visited: + return call_dependencies + visited.add(func_fqn) + + # Find variables that are used but not defined + scope = object_scope.get_scope(func_or_var) + for item in find_undefined_variables(func_or_var): + obj = object_scope.resolve_attribute_to_object(item, scope) + info = versions.VersionInfo.of(obj, version_scraping=version_scraping) + call_dependencies[info] = obj + + if info.version is None: + get_call_dependencies(obj, version_scraping, call_dependencies, visited) + return call_dependencies diff --git a/tests/unit/parsers/test_dependency_parser.py b/tests/unit/parsers/test_dependency_parser.py index 918e6a33..80e40f68 100644 --- a/tests/unit/parsers/test_dependency_parser.py +++ b/tests/unit/parsers/test_dependency_parser.py @@ -1,281 +1,132 @@ -import math +import ast +import textwrap import unittest - -from pyiron_snippets import versions +from unittest.mock import MagicMock, patch from flowrep.parsers import dependency_parser -# --------------------------------------------------------------------------- -# Helper functions defined at module level so they have inspectable source, -# a proper __module__, and a stable __qualname__. -# --------------------------------------------------------------------------- - - -def _leaf(): - return 42 - - -def _single_call(): - return _leaf() - - -def _diamond_a(): - return _leaf() - - -def _diamond_b(): - return _leaf() - - -def _diamond_root(): - _diamond_a() - _diamond_b() - - -# Mutual recursion to exercise cycle detection. -def _cycle_a(): - return _cycle_b() # noqa: F821 — defined below - - -def _cycle_b(): - return _cycle_a() - - -def _no_calls(): - x = 1 + 2 - return x - - -def _calls_len(): - return len([1, 2, 3]) - - -def _nested_call(): - return _single_call() - - -def _multi_call(): - a = _leaf() - b = _leaf() - return a + b - - -def _attribute_access(x): - return math.sqrt(x) - - -def _nested_expression(x, y, z): - return _single_call(_leaf(x, y), z) - - -def _unresolvable_subscript(): - d = {} - return d["key"]() - - -def _calls_non_callable(): - x = 42 - return x - - -def _fqn(func) -> str: - return versions.VersionInfo.of(func).fully_qualified_name - - -def _fqns(deps: dependency_parser.CallDependencies) -> set[str]: - return {info.fully_qualified_name for info in deps} - - -def _local_imports(x): - import sys as s - from math import sqrt - - a = s.getsizeof(x) - return sqrt(a) +class TestSplitByVersionAvailability(unittest.TestCase): + def test_split_by_version_availability(self): + mock_version_1 = MagicMock(version="1.0.0") + mock_version_2 = MagicMock(version=None) + mock_func_1 = MagicMock() + mock_func_2 = MagicMock() + + call_dependencies = { + mock_version_1: mock_func_1, + mock_version_2: mock_func_2, + } -def _import_from_sibling(x, y): - from .test_for_parser import pair - - a, b = pair(x, y) - return a, b - - -class TestGetCallDependencies(unittest.TestCase): - """Tests for :func:`dependency_parser.get_call_dependencies`.""" - - # --- basic behaviour --- - - def test_no_calls_returns_empty(self): - deps = dependency_parser.get_call_dependencies(_no_calls) - self.assertEqual(deps, {}) + has_version, no_version = dependency_parser.split_by_version_availability( + call_dependencies + ) - def test_single_direct_call(self): - deps = dependency_parser.get_call_dependencies(_single_call) - self.assertIn(_fqn(_leaf), _fqns(deps)) + self.assertIn(mock_version_1, has_version) + self.assertIn(mock_version_2, no_version) + self.assertNotIn(mock_version_1, no_version) + self.assertNotIn(mock_version_2, has_version) - def test_transitive_dependencies(self): - deps = dependency_parser.get_call_dependencies(_nested_call) - fqns = _fqns(deps) - # Should find both _single_call and _leaf - self.assertIn(_fqn(_single_call), fqns) - self.assertIn(_fqn(_leaf), fqns) - def test_diamond_dependency_no_duplicate_keys(self): +class TestUndefinedVariableVisitor(unittest.TestCase): + def test_undefined_variable_visitor(self): + source_code = """ + def test_function(a: int, b): + c = a + b + return d """ - _diamond_root -> _diamond_a -> _leaf AND _diamond_root -> _diamond_b -> _leaf. - _leaf's VersionInfo should appear exactly once as a key. + tree = ast.parse(textwrap.dedent(source_code)) + visitor = dependency_parser.UndefinedVariableVisitor() + visitor.visit(tree) + + self.assertIn("d", visitor.used_vars) + self.assertIn("int", visitor.used_vars) + self.assertIn("a", visitor.defined_vars) + self.assertIn("b", visitor.defined_vars) + self.assertIn("c", visitor.defined_vars) + self.assertNotIn("d", visitor.defined_vars) + + def test_all_argument_kinds_are_defined(self): + source_code = """ + def test_function(posonly, /, regular, *args, kw_only, **kwargs): + return posonly + regular + kw_only """ - deps = dependency_parser.get_call_dependencies(_diamond_root) - matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)] - self.assertEqual(len(matching), 1) - - def test_duplicate_call_deduplicated_by_version_info(self): - """Calling the same function twice yields a single key, not two.""" - deps = dependency_parser.get_call_dependencies(_multi_call) - matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)] - self.assertEqual(len(matching), 1) - - # --- cycle safety --- - - def test_cycle_does_not_recurse_infinitely(self): - # Should terminate without RecursionError - deps = dependency_parser.get_call_dependencies(_cycle_a) - self.assertIn(_fqn(_cycle_b), _fqns(deps)) - - # --- builtins / non-FunctionType callables --- - - def test_builtin_callable_included(self): - deps = dependency_parser.get_call_dependencies(_calls_len) - self.assertIn(_fqn(len), _fqns(deps)) - - def test_returns_dict_type(self): - deps = dependency_parser.get_call_dependencies(_leaf) - self.assertIsInstance(deps, dict) - - # --- attribute access (module.func) --- - - def test_attribute_access_dependency(self): - """Functions called via attribute access (e.g. math.sqrt) are tracked.""" - deps = dependency_parser.get_call_dependencies(_attribute_access) - self.assertIn(_fqn(math.sqrt), _fqns(deps)) - - # --- nested expressions --- - - def test_nested_expression_collects_all_calls(self): - """All calls in a nested expression like f(g(x), y) are collected.""" - deps = dependency_parser.get_call_dependencies(_nested_expression) - fqns = _fqns(deps) - self.assertIn(_fqn(_single_call), fqns) - self.assertIn(_fqn(_leaf), fqns) - - # --- unresolvable / non-callable targets (coverage for `continue` branches) --- - - def test_unresolvable_call_target_is_skipped(self): - """Calls that resolve_symbol_to_object cannot handle are silently skipped.""" - # _unresolvable_subscript contains d["key"]() which is an ast.Subscript, - # triggering a TypeError in resolve_symbol_to_object - deps = dependency_parser.get_call_dependencies(_unresolvable_subscript) - # Should not raise; the unresolvable call is simply absent - self.assertIsInstance(deps, dict) + tree = ast.parse(textwrap.dedent(source_code)) + visitor = dependency_parser.UndefinedVariableVisitor() + visitor.visit(tree) + + for name in ("posonly", "regular", "args", "kw_only", "kwargs"): + self.assertIn(name, visitor.defined_vars) + + def test_local_function_definition_raises(self): + source_code = """ + def outer(x): + def helper(y): + return y + return helper(x) + """ + tree = ast.parse(textwrap.dedent(source_code)) + visitor = dependency_parser.UndefinedVariableVisitor() + with self.assertRaises(NotImplementedError): + visitor.visit(tree) + + def test_local_async_function_definition_raises(self): + source_code = """ + def outer(x): + async def helper(y): + return y + return helper(x) + """ + tree = ast.parse(textwrap.dedent(source_code)) + visitor = dependency_parser.UndefinedVariableVisitor() + with self.assertRaises(NotImplementedError): + visitor.visit(tree) - def test_non_callable_resolved_symbol_is_skipped(self): - """Symbols that resolve to non-callable objects are silently skipped.""" - # _calls_non_callable doesn't actually have a call in its AST that resolves - # to a non-callable, but we can verify the function itself is crawlable - deps = dependency_parser.get_call_dependencies(_calls_non_callable) - self.assertIsInstance(deps, dict) - def test_local_imports_included(self): - deps = dependency_parser.get_call_dependencies(_local_imports) - fqns = _fqns(deps) - self.assertIn("sys.getsizeof", fqns) - self.assertIn("math.sqrt", fqns) +class TestFindUndefinedVariables(unittest.TestCase): + def test_find_undefined_variables(self): + x = 1 - def test_relative_import_raises(self): - with self.assertRaises(ValueError) as ctx: - dependency_parser.get_call_dependencies(_import_from_sibling) - self.assertIn("Relative imports are not supported", str(ctx.exception)) - self.assertIn("test_for_parser", str(ctx.exception)) + def test_function(a, b): + c = a + b + x + return c + undefined_vars = dependency_parser.find_undefined_variables(test_function) + self.assertIn("x", undefined_vars) + self.assertNotIn("a", undefined_vars) + self.assertNotIn("b", undefined_vars) + self.assertNotIn("c", undefined_vars) -class TestSplitByVersionAvailability(unittest.TestCase): - """Tests for :func:`dependency_parser.split_by_version_availability`.""" - @staticmethod - def _make_info( - module: str, qualname: str, version: str | None = None - ) -> versions.VersionInfo: - return versions.VersionInfo( - module=module, - qualname=qualname, - version=version, +class TestGetCallDependencies(unittest.TestCase): + @patch("flowrep.parsers.object_scope.get_scope") + @patch("flowrep.parsers.object_scope.resolve_attribute_to_object") + @patch("pyiron_snippets.versions.VersionInfo.of") + def test_get_call_dependencies( + self, mock_version_info_of, mock_resolve_attribute_to_object, mock_get_scope + ): + mock_func = MagicMock() + mock_version_info = MagicMock() + mock_version_info.fully_qualified_name = "mock_func" + mock_version_info_of.return_value = mock_version_info + mock_resolve_attribute_to_object.return_value = mock_func + + mock_scope = MagicMock() + mock_get_scope.return_value = mock_scope + + with patch( + "flowrep.parsers.dependency_parser.find_undefined_variables" + ) as mock_find_undefined: + mock_find_undefined.return_value = {"undefined_var"} + call_dependencies = dependency_parser.get_call_dependencies(mock_func) + + self.assertIn(mock_version_info, call_dependencies) + self.assertEqual(call_dependencies[mock_version_info], mock_func) + mock_get_scope.assert_called_once_with(mock_func) + mock_resolve_attribute_to_object.assert_called_once_with( + "undefined_var", mock_scope ) - def test_empty_input(self): - has, no = dependency_parser.split_by_version_availability({}) - self.assertEqual(has, {}) - self.assertEqual(no, {}) - - def test_all_versioned(self): - info_a = self._make_info("pkg", "a", "1.0") - info_b = self._make_info("pkg", "b", "2.0") - deps: dependency_parser.CallDependencies = {info_a: _leaf, info_b: _leaf} - - has, no = dependency_parser.split_by_version_availability(deps) - self.assertEqual(len(has), 2) - self.assertEqual(len(no), 0) - - def test_all_unversioned(self): - info_a = self._make_info("local", "a") - info_b = self._make_info("local", "b") - deps: dependency_parser.CallDependencies = {info_a: _leaf, info_b: _leaf} - - has, no = dependency_parser.split_by_version_availability(deps) - self.assertEqual(len(has), 0) - self.assertEqual(len(no), 2) - - def test_mixed(self): - versioned = self._make_info("pkg", "x", "3.1") - unversioned = self._make_info("local", "y") - deps: dependency_parser.CallDependencies = { - versioned: _leaf, - unversioned: _single_call, - } - - has, no = dependency_parser.split_by_version_availability(deps) - self.assertIn(versioned, has) - self.assertIn(unversioned, no) - self.assertNotIn(versioned, no) - self.assertNotIn(unversioned, has) - - def test_partition_is_exhaustive_and_disjoint(self): - """Every key in the input appears in exactly one partition.""" - infos = [ - self._make_info("pkg", "a", "1.0"), - self._make_info("local", "b"), - self._make_info("pkg", "c", "0.1"), - self._make_info("local", "d"), - ] - deps: dependency_parser.CallDependencies = {info: _leaf for info in infos} - - has, no = dependency_parser.split_by_version_availability(deps) - self.assertEqual(set(has) | set(no), set(deps)) - self.assertTrue(set(has).isdisjoint(set(no))) - - def test_version_none_vs_empty_string(self): - """Only ``None`` counts as unversioned; an empty string is still 'versioned'.""" - none_version = self._make_info("local", "f", None) - empty_version = self._make_info("local", "g", "") - deps: dependency_parser.CallDependencies = { - none_version: _leaf, - empty_version: _leaf, - } - - has, no = dependency_parser.split_by_version_availability(deps) - self.assertIn(none_version, no) - self.assertIn(empty_version, has) - if __name__ == "__main__": unittest.main()