Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 117 additions & 88 deletions src/flowrep/parsers/dependency_parser.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,16 @@
from __future__ import annotations

import ast
import types
import builtins
import inspect
import textwrap
from collections.abc import Callable

from pyiron_snippets import versions

from flowrep.parsers import import_parser, object_scope, parser_helpers

# Maps a dependency's version identity to the callable found under it.
CallDependencies = dict[versions.VersionInfo, Callable]


def get_call_dependencies(
    func: types.FunctionType,
    version_scraping: versions.VersionScrapingMap | None = None,
    _call_dependencies: CallDependencies | None = None,
    _visited: set[str] | None = None,
) -> CallDependencies:
    """
    Recursively collect all callable dependencies of *func* via AST introspection.

    Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo`
    and maps to the callable instance with that identity. The search is depth-first:
    for every resolved callee that is a :class:`~types.FunctionType` (i.e. has
    inspectable source), the function recurses into the callee's own scope.

    Args:
        func: The function whose call-graph to analyse.
        version_scraping (VersionScrapingMap | None): Since some modules may store
            their version in other ways, this provides an optional map between module
            names and callables to leverage for extracting that module's version.
        _call_dependencies: Accumulator for recursive calls — do not pass manually.
        _visited: Fully-qualified names already traversed — do not pass manually.

    Returns:
        A mapping from :class:`VersionInfo` to the callables found under that
        identity across the entire (sub-)tree.

    Raises:
        TypeError: If a call target resolves to a non-callable object.
    """
    # Compare against ``None`` rather than relying on truthiness: an empty
    # accumulator handed in by the caller must be reused, not silently replaced.
    call_dependencies: CallDependencies = (
        {} if _call_dependencies is None else _call_dependencies
    )
    visited: set[str] = set() if _visited is None else _visited

    # Guard against cycles (mutual or self recursion) in the call graph.
    func_fqn = versions.VersionInfo.of(func).fully_qualified_name
    if func_fqn in visited:
        return call_dependencies
    visited.add(func_fqn)

    tree = parser_helpers.get_ast_function_node(func)
    collector = CallCollector()
    collector.visit(tree)

    # Imports made inside the function body are layered on top of the function's
    # own scope so that calls through locally imported names can be resolved.
    local_modules = import_parser.build_scope(collector.imports, collector.import_froms)
    scope = object_scope.get_scope(func)
    for name, obj in local_modules.items():
        scope.register(name=name, obj=obj)

    for call in collector.calls:
        try:
            caller = object_scope.resolve_symbol_to_object(call, scope)
        except (ValueError, TypeError):
            # Call targets that cannot be resolved are skipped.
            continue

        if not callable(caller):  # pragma: no cover
            # Under remotely normal circumstances, this should be unreachable
            raise TypeError(
                f"Caller {caller} is not callable, yet was generated from the list of "
                f"ast.Call calls, in particular {call}. We're expecting these to "
                f"actually connect to callables. Please raise a GitHub issue if you "
                f"think this is not a mistake."
            )

        info = versions.VersionInfo.of(caller, version_scraping=version_scraping)
        # In principle, we open ourselves to overwriting an existing dependency here,
        # but it would need to somehow have exactly the same version info (including
        # qualname) yet be a different object.
        # This ought not happen by accident, and in case it somehow does happen on
        # purpose (it probably shouldn't), we just silently keep the more recent one.
        call_dependencies[info] = caller

        # Depth-first search on dependencies — only possible when we have source
        if isinstance(caller, types.FunctionType):
            get_call_dependencies(caller, version_scraping, call_dependencies, visited)

    return call_dependencies
# Maps a dependency's version identity to the resolved object itself.
CallDependencies = dict[versions.VersionInfo, object]


def split_by_version_availability(
Expand All @@ -105,20 +36,118 @@ def split_by_version_availability(
return has_version, no_version


class CallCollector(ast.NodeVisitor):
    """AST visitor that records call targets and import statements.

    After :meth:`visit` has run, :attr:`calls` holds the ``func`` expression of
    every :class:`ast.Call` encountered, while :attr:`imports` and
    :attr:`import_froms` hold the ``import ...`` and ``from ... import ...``
    statement nodes respectively.
    """

    def __init__(self):
        self.calls: list[ast.expr] = []
        self.imports: list[ast.Import] = []
        self.import_froms: list[ast.ImportFrom] = []

    def visit_Call(self, node: ast.Call) -> None:
        self.calls.append(node.func)
        # Keep descending so calls nested in the callee expression or in the
        # arguments are recorded as well.
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        self.imports.append(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        self.import_froms.append(node)


class UndefinedVariableVisitor(ast.NodeVisitor):
    """AST visitor that collects used and locally-defined variable names.

    The analysis is flow-insensitive and uses one flat namespace: every name
    that is read ends up in :attr:`used_vars`, every name that is bound ends up
    in :attr:`defined_vars`, and the caller takes the difference.

    Names count as bound when they are assigned (any ``Store`` context, which
    covers loop targets, ``with ... as``, comprehension variables, and the
    walrus operator), or when they are function names, function or lambda
    parameters, ``except ... as`` targets, or class names.

    Local (nested) function definitions inside the analysed function body are
    **not** supported: encountering one raises :exc:`NotImplementedError` so
    that callers fail fast with a clear message instead of silently producing
    wrong dependency results.

    Class definitions at any nesting level are tracked in :attr:`defined_vars`
    so that class names used later in the same scope are not reported as
    undefined symbols.
    """

    # NOTE(review): names bound by ``import`` statements inside the analysed
    # body are not registered as defined, so they are still reported as
    # undefined. Confirm whether downstream resolution relies on that.

    def __init__(self):
        self.used_vars: set[str] = set()
        self.defined_vars: set[str] = set()
        # Number of function definitions currently being traversed.
        self._nesting_depth: int = 0

    def visit_Name(self, node: ast.Name) -> None:
        if isinstance(node.ctx, ast.Load):
            self.used_vars.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.defined_vars.add(node.id)

    def _register_arguments(self, arguments: ast.arguments) -> None:
        """Mark every parameter of a signature as locally defined."""
        for arg in (*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs):
            self.defined_vars.add(arg.arg)
        for variadic in (arguments.vararg, arguments.kwarg):
            if variadic is not None:
                self.defined_vars.add(variadic.arg)

    def _visit_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        if self._nesting_depth > 0:
            keyword = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
            raise NotImplementedError(
                "Local function definitions are not supported: "
                f"'{keyword} {node.name}' inside a function body cannot be "
                "analysed for dependencies."
            )
        # Register the function name and all of its parameters so that
        # recursive calls and uses of any argument inside the body are not
        # reported as undefined external symbols.
        self.defined_vars.add(node.name)
        self._register_arguments(node.args)
        self._nesting_depth += 1
        try:
            self.generic_visit(node)
        finally:
            # Restore the depth even when a nested definition aborts the
            # traversal, so the counter never stays stuck above zero.
            self._nesting_depth -= 1

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._visit_function_def(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._visit_function_def(node)

    def visit_Lambda(self, node: ast.Lambda) -> None:
        # Lambda parameters are ``ast.arg`` nodes, not ``ast.Name`` stores;
        # without this they would be reported as undefined when read.
        self._register_arguments(node.args)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
        # ``except E as name`` stores the target as a plain string on the
        # handler node, so it never reaches ``visit_Name``.
        if node.name is not None:
            self.defined_vars.add(node.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self.defined_vars.add(node.name)
        self.generic_visit(node)


def find_undefined_variables(
    func_or_var: Callable | object,
) -> set[str]:
    """
    Return the names read, but never bound, in the source of *func_or_var*.

    The source is fetched with :func:`inspect.getsource`, dedented, parsed, and
    walked by :class:`UndefinedVariableVisitor`. Names that appear in
    ``dir(builtins)`` are dropped from the result.

    When the source cannot be obtained (:exc:`OSError` or :exc:`TypeError` from
    :func:`inspect.getsource`) or does not parse (:exc:`SyntaxError`), an empty
    set is returned rather than an exception being raised.
    """
    try:
        # Real source is preferred over any string representation, for
        # callables and other inspectable objects alike.
        raw = inspect.getsource(func_or_var)
    except (OSError, TypeError):
        # Nothing to inspect: report no undefined names.
        return set()

    try:
        # Dedent first so that sources of methods and other indented
        # definitions can be parsed on their own.
        module_ast = ast.parse(textwrap.dedent(raw))
    except SyntaxError:
        # Not parseable as Python: degrade gracefully.
        return set()

    collector = UndefinedVariableVisitor()
    collector.visit(module_ast)

    builtin_names = set(dir(builtins))
    return {
        name
        for name in collector.used_vars
        if name not in collector.defined_vars and name not in builtin_names
    }


def get_call_dependencies(
    func_or_var: Callable | object,
    version_scraping: versions.VersionScrapingMap | None = None,
    _call_dependencies: CallDependencies | None = None,
    _visited: set[str] | None = None,
) -> CallDependencies:
    """
    Recursively collect the external objects that *func_or_var* depends on.

    Every name that :func:`find_undefined_variables` reports for *func_or_var*
    is resolved against the object's scope and stored under its
    :class:`~pyiron_snippets.versions.VersionInfo`. Whenever the resolved
    object's version is ``None``, the search recurses into that object.

    Args:
        func_or_var: The object whose dependencies to analyse.
        version_scraping (VersionScrapingMap | None): Optional map forwarded to
            :meth:`VersionInfo.of` when building the key of each dependency.
        _call_dependencies: Accumulator for recursive calls — do not pass manually.
        _visited: Fully-qualified names already traversed — do not pass manually.

    Returns:
        A mapping from :class:`VersionInfo` to the resolved objects found
        across the entire (sub-)tree.

    Note:
        Exceptions raised while scanning the source or while resolving a name
        are not caught here and propagate to the caller.
    """
    # Compare against ``None`` rather than relying on truthiness: an empty
    # accumulator handed in by the caller must be reused, not silently replaced.
    call_dependencies: CallDependencies = (
        {} if _call_dependencies is None else _call_dependencies
    )
    visited: set[str] = set() if _visited is None else _visited

    # Guard against cycles: each fully-qualified name is analysed only once.
    func_fqn = versions.VersionInfo.of(func_or_var).fully_qualified_name
    if func_fqn in visited:
        return call_dependencies
    visited.add(func_fqn)

    # Find variables that are used but not defined
    scope = object_scope.get_scope(func_or_var)
    for item in find_undefined_variables(func_or_var):
        obj = object_scope.resolve_attribute_to_object(item, scope)
        info = versions.VersionInfo.of(obj, version_scraping=version_scraping)
        call_dependencies[info] = obj

        # Only objects without version information are descended into.
        # NOTE(review): presumably a missing version marks unpackaged/local
        # code whose own dependencies still need recording — confirm.
        if info.version is None:
            get_call_dependencies(obj, version_scraping, call_dependencies, visited)
    return call_dependencies
Loading
Loading