From 3cef0cf188f835b94297dd5d63ff5e06037d1e5e Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 26 Apr 2025 14:42:55 +0200 Subject: [PATCH 01/72] feat: customize --- .../_adapter/generic_call_adapter.py | 106 +++++++++--------- src/inline_snapshot/_customize.py | 67 +++++++++++ tests/adapter/test_dataclass.py | 10 +- 3 files changed, 124 insertions(+), 59 deletions(-) create mode 100644 src/inline_snapshot/_customize.py diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index afa67f8f..98f72127 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -9,12 +9,15 @@ from dataclasses import is_dataclass from typing import Any +from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import Default +from inline_snapshot._customize import unwrap_default + from .._change import CallArg from .._change import Delete from ..syntax_warnings import InlineSnapshotSyntaxWarning from .adapter import Adapter from .adapter import Item -from .adapter import adapter_map def get_adapter_for_type(value_type): @@ -28,15 +31,6 @@ def get_adapter_for_type(value_type): return options[0] -class Argument: - value: Any - is_default: bool = False - - def __init__(self, value, is_default=False): - self.value = value - self.is_default = is_default - - class GenericCallAdapter(Adapter): @classmethod @@ -44,7 +38,7 @@ def check_type(cls, value_type) -> bool: raise NotImplementedError(cls) @classmethod - def arguments(cls, value) -> tuple[list[Argument], dict[str, Argument]]: + def arguments(cls, value) -> CustomCall: raise NotImplementedError(cls) @classmethod @@ -54,30 +48,26 @@ def argument(cls, value, pos_or_name) -> Any: @classmethod def repr(cls, value): - args, kwargs = cls.arguments(value) + call = cls.arguments(value) - arguments = [repr(value.value) for value in args] + [ - f"{key}={repr(value.value)}" - for key, value in kwargs.items() - if not value.is_default + arguments = [repr(value) for value in call.args] + [ + f"{key}={repr(value)}" + for key, value in call.kwargs.items() + if not isinstance(value, Default) ] return f"{repr(type(value))}({', '.join(arguments)})" @classmethod def map(cls, value, map_function): - new_args, new_kwargs = cls.arguments(value) - return type(value)( - *[adapter_map(arg.value, map_function) for arg in new_args], - **{ - k: adapter_map(kwarg.value, map_function) - for k, kwarg in new_kwargs.items() - }, - ) + return cls.arguments(value).map(map_function) @classmethod def items(cls, value, node): - new_args, new_kwargs = cls.arguments(value) + + args = cls.arguments(value) + new_args = args.args + new_kwargs = args.kwargs if node is not None: assert isinstance(node, ast.Call) @@ -96,10 +86,10 @@ def pos_arg_node(_): return None return [ - Item(value=arg.value, node=pos_arg_node(i)) + Item(value=unwrap_default(arg), node=pos_arg_node(i)) for i, arg in enumerate(new_args) ] + [ - Item(value=kw.value, node=kw_arg_node(name)) + Item(value=unwrap_default(kw), node=kw_arg_node(name)) for name, kw in new_kwargs.items() ] @@ -136,7 +126,9 @@ def assign(self, old_value, old_node, new_value): ) return old_value - new_args, new_kwargs = self.arguments(new_value) + call = self.arguments(new_value) + new_args = call.args + new_kwargs = call.kwargs # positional arguments @@ -145,8 +137,8 @@ def assign(self, old_value, old_node, new_value): for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): old_value_element = self.argument(old_value, i) result = yield from self.get_adapter( - old_value_element, new_value_element.value - ).assign(old_value_element, node, new_value_element.value) + old_value_element, unwrap_default(new_value_element) + ).assign(old_value_element, node, unwrap_default(new_value_element)) result_args.append(result) if len(old_node.args) > len(new_args): @@ -166,14 +158,14 @@ def assign(self, old_value, old_node, new_value): node=old_node, arg_pos=insert_pos, arg_name=None, - new_code=self.context.file._value_to_code(value.value), - new_value=value.value, + new_code=self.context.file._value_to_code(unwrap_default(value)), + new_value=value, ) # keyword arguments result_kwargs = {} for kw in old_node.keywords: - if kw.arg not in new_kwargs or new_kwargs[kw.arg].is_default: + if kw.arg not in new_kwargs or isinstance(new_kwargs[kw.arg], Default): # delete entries yield Delete( ( @@ -192,20 +184,20 @@ def assign(self, old_value, old_node, new_value): to_insert = [] insert_pos = 0 for key, new_value_element in new_kwargs.items(): - if new_value_element.is_default: + if isinstance(new_value_element, Default): continue if key not in old_node_kwargs: # add new values - to_insert.append((key, new_value_element.value)) - result_kwargs[key] = new_value_element.value + to_insert.append((key, new_value_element)) + result_kwargs[key] = new_value_element else: node = old_node_kwargs[key] # check values with same keys old_value_element = self.argument(old_value, key) result_kwargs[key] = yield from self.get_adapter( - old_value_element, new_value_element.value - ).assign(old_value_element, node, new_value_element.value) + old_value_element, new_value_element + ).assign(old_value_element, node, new_value_element) if to_insert: for key, value in to_insert: @@ -236,7 +228,6 @@ def assign(self, old_value, old_node, new_value): new_code=self.context.file._value_to_code(value), new_value=value, ) - return type(old_value)(*result_args, **result_kwargs) @@ -265,9 +256,11 @@ def arguments(cls, value): ): is_default = True - kwargs[field.name] = Argument(value=field_value, is_default=is_default) + if is_default: + field_value = Default(field_value) + kwargs[field.name] = field_value - return ([], kwargs) + return CustomCall(type(value), *[], **kwargs) def argument(self, value, pos_or_name): if isinstance(pos_or_name, str): @@ -312,14 +305,14 @@ def arguments(cls, value): ) if default_value == field_value: - is_default = True - kwargs[field.name] = Argument( - value=field_value, is_default=is_default - ) + if is_default: + field_value = Default(field_value) - return ([], kwargs) + kwargs[field.name] = field_value + + return CustomCall(type(value), **kwargs) def argument(self, value, pos_or_name): assert isinstance(pos_or_name, str) @@ -376,9 +369,12 @@ def arguments(cls, value): ): is_default = True - kwargs[name] = Argument(value=field_value, is_default=is_default) + if is_default: + field_value = Default(field_value) + + kwargs[name] = field_value - return ([], kwargs) + return CustomCall(type(value), **kwargs) @classmethod def argument(cls, value, pos_or_name): @@ -412,10 +408,10 @@ def check_type(cls, value): @classmethod def arguments(cls, value: IsNamedTuple): - return ( - [], - { - field: Argument(value=getattr(value, field)) + return CustomCall( + type(value), + **{ + field: getattr(value, field) for field in value._fields if field not in value._field_defaults or getattr(value, field) != value._field_defaults[field] @@ -435,9 +431,9 @@ def check_type(cls, value): @classmethod def arguments(cls, value: defaultdict): - return ( - [Argument(value=value.default_factory), Argument(value=dict(value))], - {}, + return CustomCall( + type(value), + *[value.default_factory, dict(value)], ) def argument(self, value, pos_or_name): diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py new file mode 100644 index 00000000..41bc1712 --- /dev/null +++ b/src/inline_snapshot/_customize.py @@ -0,0 +1,67 @@ +custom_functions = [] + + +def customize(f): + custom_functions.append(f) + return f + + +class Custom: + pass + + +class Default: + def __init__(self, value): + self.value = value + + +def unwrap_default(value): + if isinstance(value, Default): + return value.value + return value + + +class CustomCall(Custom): + def __init__(self, function, *args, **kwargs): + """ + CustomCall(f,1,2,a=3).kwonly(b=4) + """ + self._function = function + self._args = args + self._kwargs = kwargs + self._kwonly = {} + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return {**self._kwargs, **self._kwonly} + + def kwonly(self, **kwonly): + assert not self._kwonly, "you should not call kwonly twice" + assert ( + not kwonly.keys() & self._kwargs.keys() + ), "same keys in kwargs and kwonly arguments" + self._kwonly = kwonly + return self + + def argument(self, pos_or_str): + if isinstance(pos_or_str, int): + return self._args[pos_or_str] + else: + return self.kwargs[pos_or_str] + + def map(self, f): + return self._function( + *[f(unwrap_default(x)) for x in self.args], + **{k: f(unwrap_default(v)) for k, v in self.kwargs.items()}, + ) + + +def get_handler(v): + for f in reversed(custom_functions): + r = f(v) + if isinstance(r, Custom): + return r diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index f0847176..b3badfb5 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -453,7 +453,8 @@ def test_remove_positional_argument(): """\ from inline_snapshot import snapshot -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument +from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter +from inline_snapshot._customize import CustomCall class L: @@ -472,7 +473,7 @@ def check_type(cls, value_type): @classmethod def arguments(cls, value): - return ([Argument(x) for x in value.l],{}) + return CustomCall(L,*value.l) @classmethod def argument(cls, value, pos_or_name): @@ -498,7 +499,8 @@ def test_L3(): "tests/test_something.py": """\ from inline_snapshot import snapshot -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument +from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter +from inline_snapshot._customize import CustomCall class L: @@ -517,7 +519,7 @@ def check_type(cls, value_type): @classmethod def arguments(cls, value): - return ([Argument(x) for x in value.l],{}) + return CustomCall(L,*value.l) @classmethod def argument(cls, value, pos_or_name): From 825f2b020ed65fca322f47e188b79fbe5ab3b83b Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 30 Sep 2025 08:16:20 +0200 Subject: [PATCH 02/72] refactor: cleanup --- .../_adapter/generic_call_adapter.py | 30 +------------------ src/inline_snapshot/_customize.py | 10 +++++-- 2 files changed, 8 insertions(+), 32 deletions(-) diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index 98f72127..b0b57105 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -43,7 +43,7 @@ def arguments(cls, value) -> CustomCall: @classmethod def argument(cls, value, pos_or_name) -> Any: - raise NotImplementedError(cls) + return cls.arguments(value).argument(pos_or_name) @classmethod def repr(cls, value): @@ -262,13 +262,6 @@ def arguments(cls, value): return CustomCall(type(value), *[], **kwargs) - def argument(self, value, pos_or_name): - if isinstance(pos_or_name, str): - return getattr(value, pos_or_name) - else: - args = [field for field in fields(value) if field.init] - return args[pos_or_name] - try: import attrs @@ -314,10 +307,6 @@ def arguments(cls, value): return CustomCall(type(value), **kwargs) - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - try: import pydantic @@ -376,11 +365,6 @@ def arguments(cls, value): return CustomCall(type(value), **kwargs) - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - class IsNamedTuple(ABC): _inline_snapshot_name = "namedtuple" @@ -418,10 +402,6 @@ def arguments(cls, value: IsNamedTuple): }, ) - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - class DefaultDictAdapter(GenericCallAdapter): @classmethod @@ -435,11 +415,3 @@ def arguments(cls, value: defaultdict): type(value), *[value.default_factory, dict(value)], ) - - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, int) - if pos_or_name == 0: - return value.default_factory - elif pos_or_name == 1: - return dict(value) - assert False diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 41bc1712..4febcc32 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -35,6 +35,10 @@ def __init__(self, function, *args, **kwargs): def args(self): return self._args + @property + def all_pos_args(self): + return [*self._args, *self._kwargs.values()] + @property def kwargs(self): return {**self._kwargs, **self._kwonly} @@ -49,13 +53,13 @@ def kwonly(self, **kwonly): def argument(self, pos_or_str): if isinstance(pos_or_str, int): - return self._args[pos_or_str] + return unwrap_default(self.all_pos_args[pos_or_str]) else: - return self.kwargs[pos_or_str] + return unwrap_default(self.kwargs[pos_or_str]) def map(self, f): return self._function( - *[f(unwrap_default(x)) for x in self.args], + *[f(unwrap_default(x)) for x in self._args], **{k: f(unwrap_default(v)) for k, v in self.kwargs.items()}, ) From c35d8dfe895e48601db3af9d2144908ef0b47529 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 1 Oct 2025 12:32:31 +0200 Subject: [PATCH 03/72] refactor: wip --- pyproject.toml | 4 +- src/inline_snapshot/_adapter/adapter.py | 37 +- src/inline_snapshot/_adapter/dict_adapter.py | 46 +- .../_adapter/factory_adapter.py | 30 ++ .../_adapter/generic_call_adapter.py | 229 +------- .../_adapter/sequence_adapter.py | 10 +- src/inline_snapshot/_adapter/value_adapter.py | 19 +- src/inline_snapshot/_code_repr.py | 6 +- src/inline_snapshot/_customize.py | 425 ++++++++++++++- src/inline_snapshot/_get_snapshot_value.py | 3 +- src/inline_snapshot/_new_adapter.py | 495 ++++++++++++++++++ .../_snapshot/collection_value.py | 39 +- src/inline_snapshot/_snapshot/dict_value.py | 62 ++- src/inline_snapshot/_snapshot/eq_value.py | 29 +- .../_snapshot/generic_value.py | 47 +- .../_snapshot/min_max_value.py | 40 +- .../_snapshot/undecided_value.py | 92 +++- src/inline_snapshot/testing/_example.py | 3 + tests/adapter/test_change_types.py | 2 +- tests/test_docs.py | 1 + tests/test_factory_adapter.py | 40 ++ tests/test_pytest_plugin.py | 1 + 22 files changed, 1302 insertions(+), 358 deletions(-) create mode 100644 src/inline_snapshot/_adapter/factory_adapter.py create mode 100644 src/inline_snapshot/_new_adapter.py create mode 100644 tests/test_factory_adapter.py diff --git a/pyproject.toml b/pyproject.toml index 3d517995..83e244d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,7 +173,8 @@ extra-dependencies = [ "pyright>=1.1.359", "pytest-freezer>=0.4.8", "pytest-mock>=3.14.0", - "black==25.1.0" + "black==25.1.0", + "setuptools" ] env-vars.TOP = "{root}" @@ -243,3 +244,4 @@ force_single_line=true [tool.inline-snapshot] show-updates=true +default-flags-tui=["disable"] diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py index 0ed39a40..fbe333f5 100644 --- a/src/inline_snapshot/_adapter/adapter.py +++ b/src/inline_snapshot/_adapter/adapter.py @@ -4,34 +4,53 @@ import typing from dataclasses import dataclass +import pytest + +from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomDict +from inline_snapshot._customize import CustomList +from inline_snapshot._customize import CustomTuple +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize import CustomValue from inline_snapshot._source_file import SourceFile def get_adapter_type(value): - from inline_snapshot._adapter.generic_call_adapter import get_adapter_for_type - adapter = get_adapter_for_type(type(value)) - if adapter is not None: - return adapter + if isinstance(value, CustomCall): + from .generic_call_adapter import CallAdapter - if isinstance(value, list): + pytest.skip() + + return CallAdapter + + if isinstance(value, CustomList): from .sequence_adapter import ListAdapter + pytest.skip() + return ListAdapter - if type(value) is tuple: + if isinstance(value, CustomTuple): from .sequence_adapter import TupleAdapter + pytest.skip() + return TupleAdapter - if isinstance(value, dict): + if isinstance(value, CustomDict): from .dict_adapter import DictAdapter + pytest.skip() + return DictAdapter - from .value_adapter import ValueAdapter + if isinstance(value, (CustomValue, CustomUndefined)): + from .value_adapter import ValueAdapter + + return ValueAdapter - return ValueAdapter + raise AssertionError(f"no handler for {type(value)}") class Item(typing.NamedTuple): diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py index 1f0b4a0c..44d78586 100644 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -3,6 +3,10 @@ import ast import warnings +import pytest + +from inline_snapshot._customize import CustomDict + from .._change import Delete from .._change import DictInsert from ..syntax_warnings import InlineSnapshotSyntaxWarning @@ -32,6 +36,9 @@ def map(cls, value, map_function): @classmethod def items(cls, value, node): + pytest.skip() + assert isinstance(value, CustomDict) + value = value.value if node is None or not isinstance(node, ast.Dict): return [Item(value=value, node=None) for value in value.values()] @@ -46,18 +53,25 @@ def items(cls, value, node): except Exception: pass else: - assert node_key == value_key + assert node_key == value_key.eval(), f"{node_key!r} != {value_key!r}" result.append(Item(value=value[value_key], node=node_value)) return result def assign(self, old_value, old_node, new_value): + pytest.skip() + assert isinstance(old_value, CustomDict) + assert isinstance(new_value, CustomDict) + if old_node is not None: if not ( - isinstance(old_node, ast.Dict) and len(old_value) == len(old_node.keys) + isinstance(old_node, ast.Dict) + and len(old_value.value) == len(old_node.keys) ): - result = yield from self.value_assign(old_value, old_node, new_value) + result = yield from self.value_assign( + old_value.value, old_node, new_value + ) return result for key, value in zip(old_node.keys, old_node.values): @@ -70,39 +84,45 @@ def assign(self, old_value, old_node, new_value): ) return old_value - for value, node in zip(old_value.keys(), old_node.keys): + for value, node in zip(old_value.value.keys(), old_node.keys): try: # this is just a sanity check, dicts should be ordered node_value = ast.literal_eval(node) except: continue - assert node_value == value + assert node_value == value.eval() result = {} for key, node in zip( - old_value.keys(), - (old_node.values if old_node is not None else [None] * len(old_value)), + old_value.value.keys(), + ( + old_node.values + if old_node is not None + else [None] * len(old_value.value) + ), ): if key not in new_value: # delete entries - yield Delete("fix", self.context.file._source, node, old_value[key]) + yield Delete( + "fix", self.context.file._source, node, old_value.value[key] + ) to_insert = [] insert_pos = 0 for key, new_value_element in new_value.items(): - if key not in old_value: + if key not in old_value.value: # add new values to_insert.append((key, new_value_element)) result[key] = new_value_element else: if isinstance(old_node, ast.Dict): - node = old_node.values[list(old_value.keys()).index(key)] + node = old_node.values[list(old_value.value.keys()).index(key)] else: node = None # check values with same keys result[key] = yield from self.get_adapter( - old_value[key], new_value[key] + old_value.value[key], new_value.value[key] ).assign(old_value[key], node, new_value[key]) if to_insert: @@ -137,9 +157,9 @@ def assign(self, old_value, old_node, new_value): "fix", self.context.file._source, old_node, - len(old_value), + len(old_value.value), new_code, to_insert, ) - return result + return CustomDict(result) diff --git a/src/inline_snapshot/_adapter/factory_adapter.py b/src/inline_snapshot/_adapter/factory_adapter.py new file mode 100644 index 00000000..dd182c60 --- /dev/null +++ b/src/inline_snapshot/_adapter/factory_adapter.py @@ -0,0 +1,30 @@ +"""Factory function adapter for handling factory functions like list() cleanly.""" + +from .adapter import Adapter + + +class FactoryAdapter(Adapter): + """Adapter for factory functions used in defaultdict.""" + + @classmethod + def check_type(cls, value_type): + # Check if value is a factory function (type/class or callable) + if isinstance(value_type, type): + return True + return callable(value_type) + + @classmethod + def repr(cls, value): + # Return clean name for builtin types + value_str = repr(value) + if value_str.startswith(" wrapper + return value_str + + @classmethod + def map(cls, value, map_function): + return value + + def assign(self, old_value, old_node, new_value): + # Preserve factory function identity + return old_value diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index b0b57105..0e4c46e5 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -2,15 +2,12 @@ import ast import warnings -from abc import ABC -from collections import defaultdict -from dataclasses import MISSING -from dataclasses import fields -from dataclasses import is_dataclass from typing import Any +import pytest + from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import Default +from inline_snapshot._customize import CustomDefault from inline_snapshot._customize import unwrap_default from .._change import CallArg @@ -21,49 +18,45 @@ def get_adapter_for_type(value_type): - subclasses = GenericCallAdapter.__subclasses__() - options = [cls for cls in subclasses if cls.check_type(value_type)] - - if not options: - return - - assert len(options) == 1 - return options[0] + assert False, "unreachable" + assert isinstance(value_type, CustomCall) + return CallAdapter -class GenericCallAdapter(Adapter): - - @classmethod - def check_type(cls, value_type) -> bool: - raise NotImplementedError(cls) +class CallAdapter(Adapter): @classmethod def arguments(cls, value) -> CustomCall: - raise NotImplementedError(cls) + pytest.skip() + return value @classmethod def argument(cls, value, pos_or_name) -> Any: + pytest.skip() return cls.arguments(value).argument(pos_or_name) @classmethod def repr(cls, value): + pytest.skip() call = cls.arguments(value) arguments = [repr(value) for value in call.args] + [ f"{key}={repr(value)}" for key, value in call.kwargs.items() - if not isinstance(value, Default) + if not isinstance(value, CustomDefault) ] return f"{repr(type(value))}({', '.join(arguments)})" @classmethod def map(cls, value, map_function): + pytest.skip() return cls.arguments(value).map(map_function) @classmethod def items(cls, value, node): + pytest.skip() args = cls.arguments(value) new_args = args.args @@ -94,6 +87,7 @@ def pos_arg_node(_): ] def assign(self, old_value, old_node, new_value): + pytest.skip() if old_node is None or not isinstance(old_node, ast.Call): result = yield from self.value_assign(old_value, old_node, new_value) return result @@ -165,7 +159,9 @@ def assign(self, old_value, old_node, new_value): # keyword arguments result_kwargs = {} for kw in old_node.keywords: - if kw.arg not in new_kwargs or isinstance(new_kwargs[kw.arg], Default): + if kw.arg not in new_kwargs or isinstance( + new_kwargs[kw.arg], CustomDefault + ): # delete entries yield Delete( ( @@ -184,7 +180,7 @@ def assign(self, old_value, old_node, new_value): to_insert = [] insert_pos = 0 for key, new_value_element in new_kwargs.items(): - if isinstance(new_value_element, Default): + if isinstance(new_value_element, CustomDefault): continue if key not in old_node_kwargs: # add new values @@ -216,7 +212,6 @@ def assign(self, old_value, old_node, new_value): insert_pos += 1 if to_insert: - for key, value in to_insert: yield CallArg( @@ -229,189 +224,3 @@ def assign(self, old_value, old_node, new_value): new_value=value, ) return type(old_value)(*result_args, **result_kwargs) - - -class DataclassAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return is_dataclass(value) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default != MISSING and field.default == field_value: - is_default = True - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - is_default = True - - if is_default: - field_value = Default(field_value) - kwargs[field.name] = field_value - - return CustomCall(type(value), *[], **kwargs) - - -try: - import attrs -except ImportError: # pragma: no cover - pass -else: - - class AttrAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return attrs.has(value) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for field in attrs.fields(type(value)): - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default is not attrs.NOTHING: - - default_value = ( - field.default - if not isinstance(field.default, attrs.Factory) - else ( - field.default.factory() - if not field.default.takes_self - else field.default.factory(value) - ) - ) - - if default_value == field_value: - is_default = True - - if is_default: - field_value = Default(field_value) - - kwargs[field.name] = field_value - - return CustomCall(type(value), **kwargs) - - -try: - import pydantic -except ImportError: # pragma: no cover - pass -else: - # import pydantic - if pydantic.version.VERSION.startswith("1."): - # pydantic v1 - from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] - - def get_fields(value): - return value.__fields__ - - else: - # pydantic v2 - from pydantic_core import PydanticUndefined - - def get_fields(value): - return type(value).model_fields - - from pydantic import BaseModel - - class PydanticContainer(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return issubclass(value, BaseModel) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for name, field in get_fields(value).items(): # type: ignore - if getattr(field, "repr", True): - field_value = getattr(value, name) - is_default = False - - if ( - field.default is not PydanticUndefined - and field.default == field_value - ): - is_default = True - - if ( - field.default_factory is not None - and field.default_factory() == field_value - ): - is_default = True - - if is_default: - field_value = Default(field_value) - - kwargs[name] = field_value - - return CustomCall(type(value), **kwargs) - - -class IsNamedTuple(ABC): - _inline_snapshot_name = "namedtuple" - - _fields: tuple - _field_defaults: dict - - @classmethod - def __subclasshook__(cls, t): - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -class NamedTupleAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return issubclass(value, IsNamedTuple) - - @classmethod - def arguments(cls, value: IsNamedTuple): - - return CustomCall( - type(value), - **{ - field: getattr(value, field) - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - }, - ) - - -class DefaultDictAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value): - return issubclass(value, defaultdict) - - @classmethod - def arguments(cls, value: defaultdict): - - return CustomCall( - type(value), - *[value.default_factory, dict(value)], - ) diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py index a080e840..62183e54 100644 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -4,6 +4,8 @@ import warnings from collections import defaultdict +import pytest + from .._align import add_x from .._align import align from .._change import Delete @@ -23,6 +25,7 @@ class SequenceAdapter(Adapter): @classmethod def repr(cls, value): + pytest.skip() if len(value) == 1 and cls.trailing_comma: seq = repr(value[0]) + "," else: @@ -31,19 +34,22 @@ def repr(cls, value): @classmethod def map(cls, value, map_function): + pytest.skip() result = [adapter_map(v, map_function) for v in value] return cls.value_type(result) @classmethod def items(cls, value, node): + pytest.skip() if node is None or not isinstance(node, cls.node_type): return [Item(value=v, node=None) for v in value] - assert len(value) == len(node.elts) + assert len(value.value) == len(node.elts) - return [Item(value=v, node=n) for v, n in zip(value, node.elts)] + return [Item(value=v, node=n) for v, n in zip(value.value, node.elts)] def assign(self, old_value, old_node, new_value): + pytest.skip() if old_node is not None: if not isinstance( old_node, ast.List if isinstance(old_value, list) else ast.Tuple diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py index a47e38eb..e443af4c 100644 --- a/src/inline_snapshot/_adapter/value_adapter.py +++ b/src/inline_snapshot/_adapter/value_adapter.py @@ -3,10 +3,13 @@ import ast import warnings +from inline_snapshot._customize import Custom +from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._customize import CustomValue + from .._change import Replace from .._code_repr import value_code_repr from .._sentinels import undefined -from .._unmanaged import Unmanaged from .._unmanaged import update_allowed from .._utils import value_to_token from ..syntax_warnings import InlineSnapshotInfo @@ -25,18 +28,24 @@ def map(cls, value, map_function): def assign(self, old_value, old_node, new_value): # generic fallback + assert isinstance(old_value, Custom) + assert isinstance(new_value, Custom) # because IsStr() != IsStr() - if isinstance(old_value, Unmanaged): + if isinstance(old_value, CustomUnmanaged): return old_value if old_node is None: new_token = [] else: - new_token = value_to_token(new_value) + new_token = value_to_token(new_value.eval()) - if isinstance(old_node, ast.JoinedStr) and isinstance(new_value, str): - if not old_value == new_value: + if ( + isinstance(old_node, ast.JoinedStr) + and isinstance(new_value, CustomValue) + and isinstance(new_value.value, str) + ): + if not old_value.eval() == new_value.eval(): warnings.warn_explicit( f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", filename=self.context.file._source.filename, diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 0e24cf4b..cfc1b194 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -81,11 +81,7 @@ def code_repr(obj): def mocked_code_repr(obj): - from inline_snapshot._adapter.adapter import get_adapter_type - - adapter = get_adapter_type(obj) - assert adapter is not None - return adapter.repr(obj) + return value_code_repr(obj) def value_code_repr(obj): diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 4febcc32..888e0c71 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -1,35 +1,112 @@ +from __future__ import annotations + +import ast +from abc import ABC +from abc import abstractmethod +from collections import defaultdict +from types import BuiltinFunctionType +from types import FunctionType +from typing import Any +from typing import Callable + +from inline_snapshot._code_repr import code_repr +from inline_snapshot._unmanaged import Unmanaged +from inline_snapshot._unmanaged import is_unmanaged + custom_functions = [] +from dataclasses import MISSING +from dataclasses import dataclass +from dataclasses import field +from dataclasses import fields +from dataclasses import is_dataclass + +from inline_snapshot._sentinels import undefined -def customize(f): + +def customize(f: Callable[[Any, Builder], Custom | None]): custom_functions.append(f) return f -class Custom: - pass +class Custom(ABC): + node_type: type[ast.AST] = ast.AST + + def __hash__(self): + return hash(self.eval()) + + def __eq__(self, other): + if isinstance(other, Custom): + return self.eval() == other.eval() + return NotImplemented -class Default: + @abstractmethod + def map(self, f): + raise NotImplementedError() + + @abstractmethod + def repr(self): + raise NotImplementedError() + + def eval(self): + return self.map(lambda a: a) + + +@dataclass(frozen=True) +class CustomDefault(Custom): + value: Custom = field(compare=False) + + def repr(self): + return self.value.repr() + + def map(self, f): + return self.value.map(f) + + +class CustomUnmanaged(Custom, Unmanaged): def __init__(self, value): - self.value = value + # TODO remove Unmanaged + Custom.__init__(self) + Unmanaged.__init__(self, value) + + def repr(self): + return "" + + def map(self, f): + return self.value + + +class CustomUndefined(Custom): + def __init__(self): + self.value = undefined + + def repr(self) -> str: + return "..." + + def map(self, f): + return f(undefined) def unwrap_default(value): - if isinstance(value, Default): + if isinstance(value, CustomDefault): return value.value return value +@dataclass(frozen=True) class CustomCall(Custom): - def __init__(self, function, *args, **kwargs): - """ - CustomCall(f,1,2,a=3).kwonly(b=4) - """ - self._function = function - self._args = args - self._kwargs = kwargs - self._kwonly = {} + node_type = ast.Call + _function: Custom = field(compare=False) + _args: list[Custom] = field(compare=False) + _kwargs: dict[str, Custom] = field(compare=False) + _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) + + def repr(self) -> str: + args = [] + args += [a.repr() for a in self.args] + args += [f"{k} = {v.repr()}" for k, v in self.kwargs.items()] + return f"{self._function.repr()}({', '.join(args)})" @property def args(self): @@ -58,14 +135,318 @@ def argument(self, pos_or_str): return unwrap_default(self.kwargs[pos_or_str]) def map(self, f): - return self._function( - *[f(unwrap_default(x)) for x in self._args], - **{k: f(unwrap_default(v)) for k, v in self.kwargs.items()}, + return self._function.map(f)( + *[f(x.map(f)) for x in self._args], + **{k: f(v.map(f)) for k, v in self.kwargs.items()}, + ) + + +class CustomSequenceTypes: + trailing_comma: bool + braces: str + value_type: type + + +@dataclass(frozen=True) +class CustomSequence(Custom, CustomSequenceTypes): + value: list[Custom] = field(compare=False) + + def map(self, f): + return f(self.value_type([x.map(f) for x in self.value])) + + def repr(self) -> str: + trailing_comma = self.trailing_comma and len(self.value) == 1 + return f"{self.braces[0]}{', '.join(v.repr() for v in self.value)}{', ' if trailing_comma else ''}{self.braces[1]}" + + +class CustomList(CustomSequence): + node_type = ast.List + value_type = list + braces = "[]" + trailing_comma = False + + +class CustomTuple(CustomSequence): + node_type = ast.Tuple + value_type = tuple + braces = "()" + trailing_comma = True + + +@dataclass(frozen=True) +class CustomDict(Custom): + node_type = ast.Dict + value: dict[Custom, Custom] = field(compare=False) + + def map(self, f): + return f({k.map(f): v.map(f) for k, v in self.value.items()}) + + def repr(self) -> str: + return f"{{ { ', '.join(f'{k.repr()} = {v.repr()}' for k,v in self.value.items())} }}" + + +class CustomValue(Custom): + def __init__(self, value, repr_str=None): + assert not isinstance(value, Custom) + + if repr_str is None: + self.repr_str = code_repr(value) + else: + self.repr_str = repr_str + + self.value = value + + def map(self, f): + return f(self.value) + + def repr(self) -> str: + return self.repr_str + + def __repr__(self): + return f"CustomValue({self.repr_str})" + + +@customize +def standard_handler(value, builder: Builder): + if isinstance(value, list): + return builder.List(value) + + if isinstance(value, tuple): + return builder.Tuple(value) + + if isinstance(value, dict): + return builder.Dict(value) + + +@customize +def function_handler(value, builder: Builder): + if isinstance(value, FunctionType): + return builder.Value(value, value.__qualname__) + + +@customize +def builtin_function_handler(value, builder: Builder): + if isinstance(value, BuiltinFunctionType): + return builder.Value(value, value.__name__) + + +@customize +def type_handler(value, builder: Builder): + if isinstance(value, type): + return builder.Value(value, value.__qualname__) + + +@customize +def dataclass_handler(value, builder: Builder): + + if is_dataclass(value) and not isinstance(value, type): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + is_default = False + + if field.default != MISSING and field.default == field_value: + is_default = True + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + is_default = True + + if is_default: + field_value = builder.Default(field_value) + kwargs[field.name] = field_value + + return builder.Call(value, type(value), [], kwargs, {}) + + +try: + import attrs +except ImportError: # pragma: no cover + pass +else: + + @customize + def attrs_handler(value, builder: Builder): + + if attrs.has(type(value)): + + kwargs = {} + + for field in attrs.fields(type(value)): + if field.repr: + field_value = getattr(value, field.name) + is_default = False + + if field.default is not attrs.NOTHING: + + default_value = ( + field.default + if not isinstance(field.default, attrs.Factory) + else ( + field.default.factory() + if not field.default.takes_self + else field.default.factory(value) + ) + ) + + if default_value == field_value: + is_default = True + + if is_default: + field_value = builder.Default(field_value) + + kwargs[field.name] = field_value + + return builder.Call(value, type(value), [], kwargs, {}) + + +try: + import pydantic +except ImportError: # pragma: no cover + pass +else: + # import pydantic + if pydantic.version.VERSION.startswith("1."): + # pydantic v1 + from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] + + def get_fields(value): + return value.__fields__ + + else: + # pydantic v2 + from pydantic_core import PydanticUndefined + + def get_fields(value): + return type(value).model_fields + + from pydantic import BaseModel + + @customize + def attrs_handler(value, builder: Builder): + + if isinstance(value, BaseModel): + + kwargs = {} + + for name, field in get_fields(value).items(): # type: ignore + if getattr(field, "repr", True): + field_value = getattr(value, name) + is_default = False + + if ( + field.default is not PydanticUndefined + and field.default == field_value + ): + is_default = True + + if ( + field.default_factory is not None + and field.default_factory() == field_value + ): + is_default = True + + if is_default: + field_value = builder.Default(field_value) + + kwargs[name] = field_value + + return builder.Call(value, type(value), [], kwargs, {}) + + +@customize +def namedtuple_handler(value, builder: Builder): + t = type(value) + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return + if not all(type(n) == str for n in f): + return + + # TODO handle with builder.Default + + return builder.Call( + value, + type(value), + [], + { + field: getattr(value, field) + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + }, + {}, + ) + + +@customize +def defaultdict_handler(value, builder: Builder): + if isinstance(value, defaultdict): + return builder.Call( + value, type(value), [value.default_factory, dict(value)], {}, {} ) -def get_handler(v): - for f in reversed(custom_functions): - r = f(v) - if isinstance(r, Custom): - return r +@customize +def unmanaged_handler(value, builder: Builder): + if is_unmanaged(value): + return CustomUnmanaged(value) + + +@customize +def undefined_handler(value, builder: Builder): + if value is undefined: + return CustomUndefined() + + +class Builder: + def get_handler(self, v) -> Custom: + if isinstance(v, Custom): + return v + + for f in reversed(custom_functions): + r = f(v, self) + if isinstance(r, Custom): + return r + return CustomValue(v) + + def List(self, value) -> CustomList: + custom = [self.get_handler(v) for v in value] + return CustomList(value=custom) + + def Tuple(self, value) -> CustomTuple: + custom = [self.get_handler(v) for v in value] + return CustomTuple(value=custom) + + def Call( + self, value, function, posonly_args=[], kwargs={}, kwonly_args={} + ) -> CustomCall: + function = self.get_handler(function) + posonly_args = [self.get_handler(arg) for arg in posonly_args] + kwargs = {k: self.get_handler(arg) for k, arg in kwargs.items()} + kwonly_args = {k: self.get_handler(arg) for k, arg in kwonly_args.items()} + + return CustomCall( + _function=function, + _args=posonly_args, + _kwargs=kwargs, + _kwonly=kwonly_args, + ) + + def Default(self, value) -> CustomDefault: + return CustomDefault(value=self.get_handler(value)) + + def Dict(self, value) -> CustomDict: + custom = {self.get_handler(k): self.get_handler(v) for k, v in value.items()} + return CustomDict(value=custom) + + def Value(self, value, repr) -> CustomValue: + return CustomValue(value, repr) diff --git a/src/inline_snapshot/_get_snapshot_value.py b/src/inline_snapshot/_get_snapshot_value.py index 8239f8e4..69474bed 100644 --- a/src/inline_snapshot/_get_snapshot_value.py +++ b/src/inline_snapshot/_get_snapshot_value.py @@ -1,6 +1,5 @@ from typing import TypeVar -from ._adapter.adapter import adapter_map from ._exceptions import UsageError from ._external._external import External from ._external._external_file import ExternalFile @@ -15,7 +14,7 @@ def unwrap(value): if isinstance(value, GenericValue): - return adapter_map(value._visible_value(), lambda v: unwrap(v)[0]), True + return value._visible_value().map(lambda v: unwrap(v)[0]), True if isinstance(value, (External, Outsourced, ExternalFile)): try: diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py new file mode 100644 index 00000000..cd5ca431 --- /dev/null +++ b/src/inline_snapshot/_new_adapter.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import ast +import warnings +from collections import defaultdict +from typing import Generator + +from inline_snapshot._align import add_x +from inline_snapshot._align import align +from inline_snapshot._change import CallArg +from inline_snapshot._change import Change +from inline_snapshot._change import Delete +from inline_snapshot._change import DictInsert +from inline_snapshot._change import ListInsert +from inline_snapshot._change import Replace +from inline_snapshot._compare_context import compare_context +from inline_snapshot._customize import Custom +from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomDefault +from inline_snapshot._customize import CustomDict +from inline_snapshot._customize import CustomList +from inline_snapshot._customize import CustomSequence +from inline_snapshot._customize import CustomTuple +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._customize import CustomValue +from inline_snapshot._utils import value_to_token +from inline_snapshot.syntax_warnings import InlineSnapshotInfo +from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning + + +def reeval(old_value: Custom, value: Custom) -> Custom: + + if isinstance(old_value, CustomDefault): + return reeval(old_value.value, value) + + if isinstance(value, CustomDefault): + return CustomDefault(reeval(old_value, value.value)) + + if type(old_value) is not type(value): + return CustomUnmanaged(value.eval()) + + function_name = f"reeval_{type(old_value).__name__}" + result = globals()[function_name](old_value, value) + assert isinstance(result, Custom) + assert result == value + return result + + +def reeval_CustomList(old_value: CustomList, value: CustomList): + assert len(old_value.value) == len(value.value) + return CustomList([reeval(a, b) for a, b in zip(old_value.value, value.value)]) + + +def reeval_CustomUnmanaged(old_value: CustomUnmanaged, value: CustomUnmanaged): + old_value.value = value.value + return old_value + + +def reeval_CustomUndefined(old_value, value): + return value + + +def reeval_CustomValue(old_value: CustomValue, value: CustomValue): + return value + + +def reeval_CustomCall(old_value: CustomCall, value: CustomCall): + return CustomCall( + _function=reeval(old_value._function, value._function), + _args=[reeval(a, b) for a, b in zip(old_value._args, value._args)], + _kwargs={ + k: reeval(old_value._kwargs[k], value._kwargs[k]) for k in old_value._kwargs + }, + _kwonly={ + k: reeval(old_value._kwonly[k], value._kwonly[k]) for k in old_value._kwonly + }, + ) + + +def reeval_CustomTuple(old_value, value): + assert len(old_value.value) == len(value.value) + return CustomTuple([reeval(a, b) for a, b in zip(old_value.value, value.value)]) + + +def reeval_CustomDict(old_value, value): + assert len(old_value.value) == len(value.value) + return CustomDict( + { + reeval(k1, k2): reeval(v1, v2) + for (k1, v1), (k2, v2) in zip(old_value.value.items(), value.value.items()) + } + ) + + +class NewAdapter: + + def __init__(self, context): + self.context = context + + def compare( + self, old_value: Custom, old_node, new_value: Custom + ) -> Generator[Change, None, Custom]: + + if isinstance(old_value, CustomUnmanaged): + return old_value + + if isinstance(new_value, CustomUnmanaged): + raise UsageError("unmanaged values can not be compared with snapshots") + + print("compare", old_value, new_value) + + if type(old_value) is not type(new_value) or not isinstance( + old_node, new_value.node_type + ): + result = yield from self.compare_CustomValue(old_value, old_node, new_value) + return result + + function_name = f"compare_{type(old_value).__name__}" + result = yield from getattr(self, function_name)(old_value, old_node, new_value) + + return result + + def compare_CustomValue( + self, old_value: Custom, old_node: ast.AST, new_value: Custom + ) -> Generator[Change, None, Custom]: + + assert isinstance(old_value, Custom) + assert isinstance(new_value, Custom) + + # because IsStr() != IsStr() + if isinstance(old_value, CustomUnmanaged): + return old_value + + if old_node is None: + new_token = [] + else: + new_token = value_to_token(new_value.eval()) + + if ( + isinstance(old_node, ast.JoinedStr) + and isinstance(new_value, CustomValue) + and isinstance(new_value.value, str) + ): + if not old_value.eval() == new_value.eval(): + warnings.warn_explicit( + f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", + filename=self.context.file._source.filename, + lineno=old_node.lineno, + category=InlineSnapshotInfo, + ) + return old_value + + if not old_value == new_value: + if isinstance(old_value, CustomUndefined): + flag = "create" + else: + flag = "fix" + elif ( + old_node is not None + and not isinstance(old_value, CustomUnmanaged) + and self.context.file._token_of_node(old_node) != new_token + ): + flag = "update" + else: + # equal and equal repr + return old_value + + new_code = self.context.file._token_to_code(new_token) + + yield Replace( + node=old_node, + file=self.context.file._source, + new_code=new_code, + flag=flag, + old_value=old_value.eval(), + new_value=new_value, + ) + + return new_value + + def compare_CustomSequence( + self, old_value: CustomSequence, old_node: ast.AST, new_value: CustomSequence + ) -> Generator[Change, None, CustomList]: + + if old_node is not None: + if not isinstance( + old_node, ast.List if isinstance(old_value.eval(), list) else ast.Tuple + ): + breakpoint() + assert False + + for e in old_node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + with compare_context(): + diff = add_x(align(old_value.value, new_value.value)) + old = zip( + old_value.value, + old_node.elts if old_node is not None else [None] * len(old_value), + ) + new = iter(new_value.value) + old_position = 0 + to_insert = defaultdict(list) + result = [] + for c in diff: + if c in "mx": + old_value_element, old_node_element = next(old) + new_value_element = next(new) + v = yield from self.compare( + old_value_element, old_node_element, new_value_element + ) + result.append(v) + old_position += 1 + elif c == "i": + new_value_element = next(new) + new_code = self.context.file._value_to_code(new_value_element) + result.append(new_value_element) + to_insert[old_position].append((new_code, new_value_element)) + elif c == "d": + old_value_element, old_node_element = next(old) + yield Delete( + "fix", + self.context.file._source, + old_node_element, + old_value_element, + ) + old_position += 1 + else: + assert False + + for position, code_values in to_insert.items(): + yield ListInsert( + "fix", self.context.file._source, old_node, position, *zip(*code_values) + ) + + return type(new_value)(result) + + compare_CustomTuple = compare_CustomSequence + compare_CustomList = compare_CustomSequence + + def compare_CustomDict( + self, old_value: CustomDict, old_node: ast.AST, new_value: CustomDict + ) -> Generator[Change, None, CustomDict]: + assert isinstance(old_value, CustomDict) + assert isinstance(new_value, CustomDict) + + if old_node is not None: + if not ( + isinstance(old_node, ast.Dict) + and len(old_value.value) == len(old_node.keys) + ): + result = yield from self.compare_CustomValue( + old_value, old_node, new_value + ) + return result + + for key, value in zip(old_node.keys, old_node.values): + if key is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + for value, node in zip(old_value.value.keys(), old_node.keys): + + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except: + continue + assert node_value == value.eval() + + result = {} + for key, node in zip( + old_value.value.keys(), + ( + old_node.values + if old_node is not None + else [None] * len(old_value.value) + ), + ): + if key not in new_value.value: + # delete entries + yield Delete( + "fix", self.context.file._source, node, old_value.value[key] + ) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_value.value.items(): + if key not in old_value.value: + # add new values + to_insert.append((key, new_value_element)) + result[key] = new_value_element + else: + if isinstance(old_node, ast.Dict): + node = old_node.values[list(old_value.value.keys()).index(key)] + else: + node = None + # check values with same keys + result[key] = yield from self.compare( + old_value.value[key], node, new_value.value[key] + ) + + if to_insert: + new_code = [ + ( + self.context.file._value_to_code(k), + self.context.file._value_to_code(v), + ) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context.file._source, + old_node, + insert_pos, + new_code, + to_insert, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + new_code = [ + ( + self.context.file._value_to_code(k), + self.context.file._value_to_code(v), + ) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context.file._source, + old_node, + len(old_value.value), + new_code, + to_insert, + ) + + return CustomDict(result) + + def compare_CustomCall( + self, old_value: CustomCall, old_node: ast.AST, new_value: CustomCall + ) -> Generator[Change, None, CustomCall]: + + if old_node is None or not isinstance(old_node, ast.Call): + result = yield from self.compare_CustomValue(old_value, old_node, new_value) + return result + + # positional arguments + for pos_arg in old_node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + # keyword arguments + for kw in old_node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + call = new_value + new_args = call.args + new_kwargs = call.kwargs + + # positional arguments + + result_args = [] + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): + old_value_element = old_value.argument(i) + result = yield from self.compare(old_value_element, node, new_value_element) + result_args.append(result) + + if len(old_node.args) > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + yield Delete( + "fix", + self.context.file._source, + node, + old_value.argument(arg_pos), + ) + + if len(old_node.args) < len(new_args): + for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: + yield CallArg( + flag="fix", + file=self.context.file._source, + node=old_node, + arg_pos=insert_pos, + arg_name=None, + new_code=value.repr(), + new_value=value, + ) + + # keyword arguments + result_kwargs = {} + for kw in old_node.keywords: + if kw.arg not in new_kwargs or isinstance( + new_kwargs[kw.arg], CustomDefault + ): + # delete entries + yield Delete( + ( + "update" + if old_value.argument(kw.arg) == new_value.argument(kw.arg) + else "fix" + ), + self.context.file._source, + kw.value, + old_value.argument(kw.arg), + ) + + old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_kwargs.items(): + if isinstance(new_value_element, CustomDefault): + continue + if key not in old_node_kwargs: + # add new values + to_insert.append((key, new_value_element)) + result_kwargs[key] = new_value_element + else: + node = old_node_kwargs[key] + + # check values with same keys + old_value_element = old_value.argument(key) + result_kwargs[key] = yield from self.compare( + old_value_element, node, new_value_element + ) + + if to_insert: + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context.file._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=value.repr(), + new_value=value, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context.file._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=value.repr(), + new_value=value, + ) + print(new_value._function) + return CustomCall( + _function=( + yield from self.compare( + old_value._function, old_node.func, new_value._function + ) + ), + _args=result_args, + _kwargs=result_kwargs, + ) diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 951ca021..e5ad7bbb 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -1,15 +1,17 @@ import ast from typing import Iterator +from inline_snapshot._customize import Builder +from inline_snapshot._customize import CustomList +from inline_snapshot._customize import CustomUndefined + from .._change import Change from .._change import Delete from .._change import ListInsert from .._change import Replace from .._global_state import state -from .._sentinels import undefined from .._utils import value_to_token from .generic_value import GenericValue -from .generic_value import clone from .generic_value import ignore_old_value @@ -17,33 +19,36 @@ class CollectionValue(GenericValue): _current_op = "x in snapshot" def __contains__(self, item): - if self._old_value is undefined: + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if self._new_value is undefined: - self._new_value = [clone(item)] + if isinstance(self._new_value, CustomUndefined): + self._new_value = CustomList([item], [Builder().get_handler(item)]) else: - if item not in self._new_value: - self._new_value.append(clone(item)) + if item not in self._new_value.value: + self._new_value.value.append(Builder().get_handler(item)) - if ignore_old_value() or self._old_value is undefined: + if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True else: - return self._return(item in self._old_value) + return self._return(item in self._old_value.eval()) def _new_code(self): - return self._file._value_to_code(self._new_value) + # TODO repr() ... + return self._file._value_to_code(self._new_value.eval()) def _get_changes(self) -> Iterator[Change]: + assert isinstance(self._old_value, CustomList) + assert isinstance(self._new_value, CustomList), self._new_value if self._ast_node is None: - elements = [None] * len(self._old_value) + elements = [None] * len(self._old_value.value) else: assert isinstance(self._ast_node, ast.List) elements = self._ast_node.elts - for old_value, old_node in zip(self._old_value, elements): - if old_value not in self._new_value: + for old_value, old_node in zip(self._old_value.value, elements): + if old_value not in self._new_value.value: yield Delete( flag="trim", file=self._file, @@ -53,7 +58,7 @@ def _get_changes(self) -> Iterator[Change]: continue # check for update - new_token = value_to_token(old_value) + new_token = value_to_token(old_value.eval()) if ( old_node is not None @@ -70,13 +75,15 @@ def _get_changes(self) -> Iterator[Change]: new_value=old_value, ) - new_values = [v for v in self._new_value if v not in self._old_value] + new_values = [ + v.eval() for v in self._new_value.value if v not in self._old_value.value + ] if new_values: yield ListInsert( flag="fix", file=self._file, node=self._ast_node, - position=len(self._old_value), + position=len(self._old_value.value), new_code=[self._file._value_to_code(v) for v in new_values], new_values=new_values, ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index afed0073..42c5f31f 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -1,6 +1,12 @@ import ast from typing import Iterator +import pytest + +from inline_snapshot._customize import Builder +from inline_snapshot._customize import CustomDict +from inline_snapshot._customize import CustomUndefined + from .._adapter.adapter import AdapterContext from .._change import Change from .._change import Delete @@ -16,35 +22,43 @@ class DictValue(GenericValue): def __getitem__(self, index): - if self._new_value is undefined: - self._new_value = {} + pytest.skip() + + if isinstance(self._new_value, CustomUndefined): + self._new_value = CustomDict({}, {}) - if index not in self._new_value: - old_value = self._old_value - if old_value is undefined: + index = Builder().get_handler(index) + + if index not in self._new_value.value: + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 old_value = {} + else: + old_value = self._old_value.value child_node = None if self._ast_node is not None: assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) + old_keys = [k for k in old_value.keys()] + if index in old_keys: + pos = old_keys.index(index) child_node = self._ast_node.values[pos] - self._new_value[index] = UndecidedValue( + self._new_value.value[index] = UndecidedValue( old_value.get(index, undefined), child_node, self._context ) - return self._new_value[index] + return self._new_value.value[index] def _re_eval(self, value, context: AdapterContext): super()._re_eval(value, context) - if self._new_value is not undefined and self._old_value is not undefined: - for key, s in self._new_value.items(): - if key in self._old_value: - s._re_eval(self._old_value[key], context) + if not isinstance(self._new_value, CustomUndefined) and not isinstance( + self._old_value, CustomUndefined + ): + for key, s in self._new_value.value.items(): + if key in self._old_value.value: + s._re_eval(self._old_value.value[key], context) def _new_code(self): return ( @@ -52,7 +66,7 @@ def _new_code(self): + ", ".join( [ f"{self._file._value_to_code(k)}: {v._new_code()}" - for k, v in self._new_value.items() + for k, v in self._new_value.value.items() if not isinstance(v, UndecidedValue) ] ) @@ -61,37 +75,37 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: - assert self._old_value is not undefined + assert not isinstance(self._old_value, CustomUndefined) if self._ast_node is None: - values = [None] * len(self._old_value) + values = [None] * len(self._old_value.value) else: assert isinstance(self._ast_node, ast.Dict) values = self._ast_node.values - for key, node in zip(self._old_value.keys(), values): - if key in self._new_value: + for key, node in zip(self._old_value.value.keys(), values): + if key in self._new_value.value: # check values with same keys - yield from self._new_value[key]._get_changes() + yield from self._new_value.value[key]._get_changes() else: # delete entries - yield Delete("trim", self._file, node, self._old_value[key]) + yield Delete("trim", self._file, node, self._old_value.value[key]) to_insert = [] - for key, new_value_element in self._new_value.items(): - if key not in self._old_value and not isinstance( + for key, new_value_element in self._new_value.value.items(): + if key not in self._old_value.value and not isinstance( new_value_element, UndecidedValue ): # add new values to_insert.append((key, new_value_element._new_code())) if to_insert: - new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] + new_code = [(self._file._value_to_code(k.eval()), v) for k, v in to_insert] yield DictInsert( "create", self._file, self._ast_node, - len(self._old_value), + len(self._old_value.value), new_code, to_insert, ) diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 9e141964..7fb441ab 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -1,14 +1,14 @@ from typing import Iterator from typing import List -from inline_snapshot._adapter.adapter import Adapter +from inline_snapshot._customize import Builder +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._new_adapter import NewAdapter from .._change import Change from .._compare_context import compare_only from .._global_state import state -from .._sentinels import undefined from .generic_value import GenericValue -from .generic_value import clone class EqValue(GenericValue): @@ -16,13 +16,19 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - if self._old_value is undefined: + other = Builder().get_handler(other) + print("===") + print(self._old_value) + print(other) + + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if not compare_only() and self._new_value is undefined: + if not compare_only() and isinstance(self._new_value, CustomUndefined): self._changes = [] - adapter = Adapter(self._context).get_adapter(self._old_value, other) - it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + + adapter = NewAdapter(self._context) + it = iter(adapter.compare(self._old_value, self._ast_node, other)) while True: try: self._changes.append(next(it)) @@ -30,10 +36,13 @@ def __eq__(self, other): self._new_value = ex.value break - return self._return(self._old_value == other, self._new_value == other) + return self._return( + self._old_value.eval() == other.eval(), + self._new_value.eval() == other.eval(), + ) def _new_code(self): - return self._file._value_to_code(self._new_value) + return self._new_value.repr() def _get_changes(self) -> Iterator[Change]: - return iter(self._changes) + return iter(getattr(self, "_changes", [])) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 29db3fc8..63b0a8a6 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -1,19 +1,23 @@ import ast import copy -from typing import Any from typing import Iterator +import pytest + +from inline_snapshot._customize import Builder +from inline_snapshot._customize import Custom +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._new_adapter import reeval + from .._adapter.adapter import AdapterContext from .._adapter.adapter import get_adapter_type from .._change import Change from .._code_repr import code_repr from .._exceptions import UsageError from .._global_state import state -from .._sentinels import undefined from .._types import SnapshotBase -from .._unmanaged import Unmanaged from .._unmanaged import declare_unmanaged -from .._unmanaged import update_allowed def clone(obj): @@ -40,8 +44,8 @@ def ignore_old_value(): @declare_unmanaged class GenericValue(SnapshotBase): - _new_value: Any - _old_value: Any + _new_value: Custom + _old_value: Custom _current_op = "undefined" _ast_node: ast.Expr _context: AdapterContext @@ -52,7 +56,12 @@ def _return(self, result, new_result=True): state().incorrect_values += 1 flags = state().update_flags - if flags.fix or flags.create or flags.update or self._old_value is undefined: + if ( + flags.fix + or flags.create + or flags.update + or isinstance(self._old_value, CustomUndefined) + ): return new_result return result @@ -64,10 +73,14 @@ def get_adapter(self, value): return get_adapter_type(value)(self._context) def _re_eval(self, value, context: AdapterContext): + + self._old_value = reeval(self._old_value, Builder().get_handler(value)) + return + self._context = context def re_eval(old_value, node, value): - if isinstance(old_value, Unmanaged): + if isinstance(old_value, CustomUnmanaged): old_value.value = value return @@ -83,22 +96,24 @@ def re_eval(old_value, node, value): re_eval(old_item.value, old_item.node, new_item.value) else: - if update_allowed(old_value): - if not old_value == value: + if not isinstance(old_value, CustomUnmanaged): + if not old_value.eval() == value.eval(): raise UsageError( "snapshot value should not change. Use Is(...) for dynamic snapshot parts." ) else: - assert False, "old_value should be converted to Unmanaged" + assert ( + False + ), f"old_value ({type(old_value)}) should be converted to Unmanaged" - re_eval(self._old_value, self._ast_node, value) + re_eval(self._old_value, self._ast_node, Builder().get_handler(value)) def _ignore_old(self): return ( state().update_flags.fix or state().update_flags.update or state().update_flags.create - or self._old_value is undefined + or isinstance(self._old_value, CustomUndefined) ) def _visible_value(self): @@ -114,7 +129,7 @@ def _new_code(self): raise NotImplementedError() def __repr__(self): - return repr(self._visible_value()) + return repr(self._visible_value().eval()) def _type_error(self, op): __tracebackhide__ = True @@ -127,17 +142,21 @@ def __eq__(self, _other): self._type_error("==") def __le__(self, _other): + pytest.skip() __tracebackhide__ = True self._type_error("<=") def __ge__(self, _other): + pytest.skip() __tracebackhide__ = True self._type_error(">=") def __contains__(self, _other): + pytest.skip() __tracebackhide__ = True self._type_error("in") def __getitem__(self, _item): + pytest.skip() __tracebackhide__ = True self._type_error("snapshot[key]") diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 9ef0a65c..ccf9738e 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,12 +1,15 @@ from typing import Iterator +import pytest + +from inline_snapshot._customize import Builder +from inline_snapshot._customize import CustomUndefined + from .._change import Change from .._change import Replace from .._global_state import state -from .._sentinels import undefined from .._utils import value_to_token from .generic_value import GenericValue -from .generic_value import clone from .generic_value import ignore_old_value @@ -18,28 +21,33 @@ def cmp(a, b): raise NotImplementedError def _generic_cmp(self, other): - if self._old_value is undefined: + pytest.skip() + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if self._new_value is undefined: - self._new_value = clone(other) - if self._old_value is undefined or ignore_old_value(): + if isinstance(self._new_value, CustomUndefined): + self._new_value = Builder().get_handler(other) + if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True - return self._return(self.cmp(self._old_value, other)) + return self._return(self.cmp(self._old_value.eval(), other)) else: - if not self.cmp(self._new_value, other): - self._new_value = clone(other) + if not self.cmp(self._new_value.eval(), other): + self._new_value = Builder.get_handler(other) - return self._return(self.cmp(self._visible_value(), other)) + return self._return(self.cmp(self._visible_value().eval(), other)) def _new_code(self): - return self._file._value_to_code(self._new_value) + # TODO repr() ... + return self._file._value_to_code(self._new_value.eval()) def _get_changes(self) -> Iterator[Change]: - new_token = value_to_token(self._new_value) - if not self.cmp(self._old_value, self._new_value): + pytest.skip() + # TODO repr() ... + new_token = value_to_token(self._new_value.eval()) + + if not self.cmp(self._old_value.eval(), self._new_value.eval()): flag = "fix" - elif not self.cmp(self._new_value, self._old_value): + elif not self.cmp(self._new_value.eval(), self._old_value.eval()): flag = "trim" elif ( self._ast_node is not None @@ -56,8 +64,8 @@ def _get_changes(self) -> Iterator[Change]: file=self._file, new_code=new_code, flag=flag, - old_value=self._old_value, - new_value=self._new_value, + old_value=self._old_value.eval(), + new_value=self._new_value.eval(), ) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 3098d2dc..e2a94fac 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -1,24 +1,100 @@ +import ast from typing import Iterator -from inline_snapshot._adapter.adapter import adapter_map +from inline_snapshot._customize import Builder +from inline_snapshot._customize import Custom +from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomDict +from inline_snapshot._customize import CustomList +from inline_snapshot._customize import CustomTuple +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._customize import CustomValue from .._adapter.adapter import AdapterContext from .._adapter.adapter import get_adapter_type from .._change import Change from .._change import Replace -from .._sentinels import undefined -from .._unmanaged import Unmanaged -from .._unmanaged import map_unmanaged from .._utils import value_to_token from .generic_value import GenericValue +def verify(value: Custom, node: ast.AST, eval) -> Custom: + """Verify that a Custom value matches its corresponding AST node structure.""" + if isinstance(node, ast.List): + return verify_list(value, node, eval) + elif isinstance(node, ast.Tuple): + return verify_tuple(value, node, eval) + elif isinstance(node, ast.Dict): + return verify_dict(value, node, eval) + elif isinstance(node, ast.Call): + return verify_call(value, node, eval) + else: + # For other types, return the value as-is + return value + + +def verify_list(value: Custom, node: ast.List, eval) -> Custom: + """Verify a CustomList matches its List AST node.""" + assert isinstance(value, CustomList) + return CustomList([verify(v, n, eval) for v, n in zip(value.value, node.elts)]) + + +def verify_tuple(value: Custom, node: ast.Tuple, eval) -> Custom: + """Verify a CustomTuple matches its Tuple AST node.""" + assert isinstance(value, CustomTuple) + return CustomTuple([verify(v, n, eval) for v, n in zip(value.value, node.elts)]) + + +def verify_dict(value: Custom, node: ast.Dict, eval) -> Custom: + """Verify a CustomDict matches its Dict AST node.""" + assert isinstance(value, CustomDict) + verified_items = {} + for (key, val), key_node, val_node in zip( + value.value.items(), node.keys, node.values + ): + verified_key = verify(key, key_node, eval) if key_node else key + verified_val = verify(val, val_node, eval) + verified_items[verified_key] = verified_val + return CustomDict(value=verified_items) + + +def verify_call(value: Custom, node: ast.Call, eval) -> Custom: + """Verify a CustomCall matches its Call AST node.""" + + if not isinstance(value, CustomCall) or eval(node.func) != value._function.eval(): + return CustomValue(eval(node), ast.unparse(node)) + + # Verify function + verified_function = verify(value._function, node.func, eval) + + # Verify positional arguments + verified_args = [] + for arg, arg_node in zip(value._args, node.args): + verified_args.append(verify(arg, arg_node, eval)) + + # Verify keyword arguments + verified_kwargs = {} + keyword_map = {kw.arg: kw.value for kw in node.keywords if kw.arg} + for key, val in value._kwargs.items(): + if key in keyword_map: + verified_kwargs[key] = verify(val, keyword_map[key], eval) + else: + verified_kwargs[key] = val + + return CustomCall( + _function=verified_function, _args=verified_args, _kwargs=verified_kwargs + ) + + class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): - old_value = adapter_map(old_value, map_unmanaged) + old_value = verify(Builder().get_handler(old_value), ast_node, context.eval) + + assert isinstance(old_value, Custom) self._old_value = old_value - self._new_value = undefined + self._new_value = CustomUndefined() self._ast_node = ast_node self._context = context @@ -38,8 +114,8 @@ def handle(node, obj): yield from handle(item.node, item.value) return - if not isinstance(obj, Unmanaged) and node is not None: - new_token = value_to_token(obj) + if not isinstance(obj, CustomUnmanaged) and node is not None: + new_token = value_to_token(obj.eval()) if self._file._token_of_node(node) != new_token: new_code = self._file._token_to_code(new_token) diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 4917cfcf..ebe0cd06 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -445,6 +445,9 @@ def run_pytest( Returns: A new Example instance containing the changed files. """ + import pytest + + pytest.skip() self.dump_files() with TemporaryDirectory() as dir: diff --git a/tests/adapter/test_change_types.py b/tests/adapter/test_change_types.py index add6804f..f28d3c17 100644 --- a/tests/adapter/test_change_types.py +++ b/tests/adapter/test_change_types.py @@ -31,7 +31,7 @@ def f(v): def code_repr(v): g = {} - exec(context + f"r=repr({a})", g) + exec(context + f"r=repr({v})", g) return g["r"] def code(a, b): diff --git a/tests/test_docs.py b/tests/test_docs.py index 26ea291c..142e8b2c 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -269,6 +269,7 @@ def change_block(block): ], ) def test_docs(file): + pytest.skip() file_test(file) diff --git a/tests/test_factory_adapter.py b/tests/test_factory_adapter.py new file mode 100644 index 00000000..572db777 --- /dev/null +++ b/tests/test_factory_adapter.py @@ -0,0 +1,40 @@ +"""Tests for the factory adapter handling.""" + +from collections import defaultdict + +from inline_snapshot import snapshot + + +def test_factory_adapter_defaultdict(): + """Test that factory functions in defaultdict are handled correctly.""" + d = defaultdict(list) + d["test"].append(1) + d["other"].append(2) + snapshot(d) == defaultdict(list, {"test": [1], "other": [2]}) + + +def test_factory_adapter_subclass(): + """Test that custom factory functions are handled correctly.""" + + class CustomFactory: + def __call__(self): + return 42 + + d = defaultdict(CustomFactory()) + d["test"] # Access to trigger factory + snapshot(d) == defaultdict(CustomFactory(), {"test": 42}) + + +def test_factory_adapter_lambda(): + """Test that lambda factory functions are handled correctly.""" + d = defaultdict(lambda: "default") + d["test"] # Access to trigger factory + snapshot(d) == defaultdict(lambda: "default", {"test": "default"}) + + +def test_factory_adapter_builtin_types(): + """Test that builtin type factories are handled correctly.""" + for factory in (list, dict, set, str, int, float): + d = defaultdict(factory) + d["test"] # Access to trigger factory + snapshot(d) == defaultdict(factory, {"test": factory()}) diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 2a9cb981..51828c70 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -510,6 +510,7 @@ def test_assertion_error(project): @pytest.mark.no_rewriting def test_run_without_pytest(pytester): + pytest.skip() # snapshots are deactivated by default pytester.makepyfile( test_file=""" From 05598f04df5c0b6e84929054919acca17b91e743 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 28 Nov 2025 22:08:52 +0100 Subject: [PATCH 04/72] refactoring --- src/inline_snapshot/_code_repr.py | 4 ++- src/inline_snapshot/_customize.py | 34 ++++++++++++------- src/inline_snapshot/_get_snapshot_value.py | 4 --- src/inline_snapshot/_new_adapter.py | 4 +++ .../_snapshot/undecided_value.py | 11 +++++- src/inline_snapshot/_unmanaged.py | 20 ----------- src/inline_snapshot/fix_pytest_diff.py | 25 +++++++------- 7 files changed, 51 insertions(+), 51 deletions(-) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index cfc1b194..5edc93cf 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -81,7 +81,9 @@ def code_repr(obj): def mocked_code_repr(obj): - return value_code_repr(obj) + from inline_snapshot._customize import Builder + + return Builder().get_handler(obj).repr() def value_code_repr(obj): diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 888e0c71..bfe332ab 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -3,14 +3,14 @@ import ast from abc import ABC from abc import abstractmethod +from collections import Counter from collections import defaultdict from types import BuiltinFunctionType from types import FunctionType from typing import Any from typing import Callable -from inline_snapshot._code_repr import code_repr -from inline_snapshot._unmanaged import Unmanaged +from inline_snapshot._code_repr import value_code_repr from inline_snapshot._unmanaged import is_unmanaged custom_functions = [] @@ -64,11 +64,9 @@ def map(self, f): return self.value.map(f) -class CustomUnmanaged(Custom, Unmanaged): - def __init__(self, value): - # TODO remove Unmanaged - Custom.__init__(self) - Unmanaged.__init__(self, value) +@dataclass() +class CustomUnmanaged(Custom): + value: Any def repr(self): return "" @@ -105,7 +103,11 @@ class CustomCall(Custom): def repr(self) -> str: args = [] args += [a.repr() for a in self.args] - args += [f"{k} = {v.repr()}" for k, v in self.kwargs.items()] + args += [ + f"{k}={v.repr()}" + for k, v in self.kwargs.items() + if not isinstance(v, CustomDefault) + ] return f"{self._function.repr()}({', '.join(args)})" @property @@ -182,7 +184,9 @@ def map(self, f): return f({k.map(f): v.map(f) for k, v in self.value.items()}) def repr(self) -> str: - return f"{{ { ', '.join(f'{k.repr()} = {v.repr()}' for k,v in self.value.items())} }}" + return ( + f"{{{ ', '.join(f'{k.repr()}: {v.repr()}' for k,v in self.value.items())}}}" + ) class CustomValue(Custom): @@ -190,7 +194,7 @@ def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) if repr_str is None: - self.repr_str = code_repr(value) + self.repr_str = value_code_repr(value) else: self.repr_str = repr_str @@ -211,13 +215,19 @@ def standard_handler(value, builder: Builder): if isinstance(value, list): return builder.List(value) - if isinstance(value, tuple): + if type(value) is tuple: return builder.Tuple(value) if isinstance(value, dict): return builder.Dict(value) +@customize +def counter_handler(value, builder: Builder): + if isinstance(value, Counter): + return builder.Call(value, Counter, [dict(value)]) + + @customize def function_handler(value, builder: Builder): if isinstance(value, FunctionType): @@ -398,7 +408,7 @@ def defaultdict_handler(value, builder: Builder): @customize def unmanaged_handler(value, builder: Builder): if is_unmanaged(value): - return CustomUnmanaged(value) + return CustomUnmanaged(value=value) @customize diff --git a/src/inline_snapshot/_get_snapshot_value.py b/src/inline_snapshot/_get_snapshot_value.py index 69474bed..abe99f89 100644 --- a/src/inline_snapshot/_get_snapshot_value.py +++ b/src/inline_snapshot/_get_snapshot_value.py @@ -9,7 +9,6 @@ from ._is import Is from ._snapshot.generic_value import GenericValue from ._types import Snapshot -from ._unmanaged import Unmanaged def unwrap(value): @@ -22,9 +21,6 @@ def unwrap(value): except (UsageError, StorageLookupError): return (None, False) - if isinstance(value, Unmanaged): - return unwrap(value.value)[0], True - if isinstance(value, Is): return value.value, True diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index cd5ca431..d1607653 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -43,6 +43,10 @@ def reeval(old_value: Custom, value: Custom) -> Custom: function_name = f"reeval_{type(old_value).__name__}" result = globals()[function_name](old_value, value) assert isinstance(result, Custom) + + if not result == value: + breakpoint() + assert result == value return result diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index e2a94fac..3759ad6b 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -4,6 +4,7 @@ from inline_snapshot._customize import Builder from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomDefault from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomTuple @@ -21,6 +22,11 @@ def verify(value: Custom, node: ast.AST, eval) -> Custom: """Verify that a Custom value matches its corresponding AST node structure.""" + if isinstance(value, CustomUnmanaged): + return value + if isinstance(value, CustomDefault): + return CustomDefault(value=verify(value.value, node, eval)) + if isinstance(node, ast.List): return verify_list(value, node, eval) elif isinstance(node, ast.Tuple): @@ -90,7 +96,10 @@ def verify_call(value: Custom, node: ast.Call, eval) -> Custom: class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): - old_value = verify(Builder().get_handler(old_value), ast_node, context.eval) + old_value = Builder().get_handler(old_value) + print("before verify", old_value) + old_value = verify(old_value, ast_node, context.eval) + print("after verify", old_value) assert isinstance(old_value, Custom) self._old_value = old_value diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py index 3f33393a..9d15e626 100644 --- a/src/inline_snapshot/_unmanaged.py +++ b/src/inline_snapshot/_unmanaged.py @@ -29,23 +29,3 @@ def declare_unmanaged(data_type): global unmanaged_types unmanaged_types.append(data_type) return data_type - - -class Unmanaged: - def __init__(self, value): - self.value = value - - def __eq__(self, other): - assert not isinstance(other, Unmanaged) - - return self.value == other - - def __repr__(self): - return repr(self.value) - - -def map_unmanaged(value): - if is_unmanaged(value): - return Unmanaged(value) - else: - return value diff --git a/src/inline_snapshot/fix_pytest_diff.py b/src/inline_snapshot/fix_pytest_diff.py index b481f2f2..e33ff366 100644 --- a/src/inline_snapshot/fix_pytest_diff.py +++ b/src/inline_snapshot/fix_pytest_diff.py @@ -4,7 +4,6 @@ from inline_snapshot._is import Is from inline_snapshot._snapshot.generic_value import GenericValue -from inline_snapshot._unmanaged import Unmanaged def fix_pytest_diff(): @@ -23,18 +22,18 @@ def _pprint_snapshot( PrettyPrinter._dispatch[GenericValue.__repr__] = _pprint_snapshot - def _pprint_unmanaged( - self, - object: Any, - stream: IO[str], - indent: int, - allowance: int, - context: Set[int], - level: int, - ) -> None: - self._format(object.value, stream, indent, allowance, context, level) - - PrettyPrinter._dispatch[Unmanaged.__repr__] = _pprint_unmanaged + # def _pprint_unmanaged( + # self, + # object: Any, + # stream: IO[str], + # indent: int, + # allowance: int, + # context: Set[int], + # level: int, + # ) -> None: + # self._format(object.value, stream, indent, allowance, context, level) + + # PrettyPrinter._dispatch[Unmanaged.__repr__] = _pprint_unmanaged def _pprint_is( self, From 5d03e4d73c719e1d68f4e36917079a83f2e75f45 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 29 Nov 2025 07:34:35 +0100 Subject: [PATCH 05/72] refactoring --- pyproject.toml | 1 - src/inline_snapshot/_adapter/adapter.py | 10 --- src/inline_snapshot/_adapter/dict_adapter.py | 4 -- .../_adapter/generic_call_adapter.py | 8 --- .../_adapter/sequence_adapter.py | 6 -- src/inline_snapshot/_code_repr.py | 2 +- src/inline_snapshot/_customize.py | 15 ++++- src/inline_snapshot/_inline_snapshot.py | 5 +- src/inline_snapshot/_new_adapter.py | 15 +++-- .../_snapshot/collection_value.py | 6 +- src/inline_snapshot/_snapshot/dict_value.py | 10 +-- src/inline_snapshot/_snapshot/eq_value.py | 6 +- .../_snapshot/generic_value.py | 26 -------- .../_snapshot/min_max_value.py | 6 +- .../_snapshot/undecided_value.py | 65 +++++++++++-------- src/inline_snapshot/_utils.py | 28 +++++++- src/inline_snapshot/testing/_example.py | 2 - tests/adapter/test_dataclass.py | 42 +++--------- tests/test_docs.py | 1 - tests/test_factory_adapter.py | 40 ------------ tests/test_preserve_values.py | 2 +- tests/test_pytest_plugin.py | 1 - 22 files changed, 109 insertions(+), 192 deletions(-) delete mode 100644 tests/test_factory_adapter.py diff --git a/pyproject.toml b/pyproject.toml index 83e244d7..aabf921f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -244,4 +244,3 @@ force_single_line=true [tool.inline-snapshot] show-updates=true -default-flags-tui=["disable"] diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py index fbe333f5..5757bca3 100644 --- a/src/inline_snapshot/_adapter/adapter.py +++ b/src/inline_snapshot/_adapter/adapter.py @@ -4,8 +4,6 @@ import typing from dataclasses import dataclass -import pytest - from inline_snapshot._customize import CustomCall from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomList @@ -20,29 +18,21 @@ def get_adapter_type(value): if isinstance(value, CustomCall): from .generic_call_adapter import CallAdapter - pytest.skip() - return CallAdapter if isinstance(value, CustomList): from .sequence_adapter import ListAdapter - pytest.skip() - return ListAdapter if isinstance(value, CustomTuple): from .sequence_adapter import TupleAdapter - pytest.skip() - return TupleAdapter if isinstance(value, CustomDict): from .dict_adapter import DictAdapter - pytest.skip() - return DictAdapter if isinstance(value, (CustomValue, CustomUndefined)): diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py index 44d78586..7740b767 100644 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -3,8 +3,6 @@ import ast import warnings -import pytest - from inline_snapshot._customize import CustomDict from .._change import Delete @@ -36,7 +34,6 @@ def map(cls, value, map_function): @classmethod def items(cls, value, node): - pytest.skip() assert isinstance(value, CustomDict) value = value.value if node is None or not isinstance(node, ast.Dict): @@ -60,7 +57,6 @@ def items(cls, value, node): return result def assign(self, old_value, old_node, new_value): - pytest.skip() assert isinstance(old_value, CustomDict) assert isinstance(new_value, CustomDict) diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index 0e4c46e5..173465af 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -4,8 +4,6 @@ import warnings from typing import Any -import pytest - from inline_snapshot._customize import CustomCall from inline_snapshot._customize import CustomDefault from inline_snapshot._customize import unwrap_default @@ -27,17 +25,14 @@ class CallAdapter(Adapter): @classmethod def arguments(cls, value) -> CustomCall: - pytest.skip() return value @classmethod def argument(cls, value, pos_or_name) -> Any: - pytest.skip() return cls.arguments(value).argument(pos_or_name) @classmethod def repr(cls, value): - pytest.skip() call = cls.arguments(value) @@ -51,12 +46,10 @@ def repr(cls, value): @classmethod def map(cls, value, map_function): - pytest.skip() return cls.arguments(value).map(map_function) @classmethod def items(cls, value, node): - pytest.skip() args = cls.arguments(value) new_args = args.args @@ -87,7 +80,6 @@ def pos_arg_node(_): ] def assign(self, old_value, old_node, new_value): - pytest.skip() if old_node is None or not isinstance(old_node, ast.Call): result = yield from self.value_assign(old_value, old_node, new_value) return result diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py index 62183e54..f917d246 100644 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -4,8 +4,6 @@ import warnings from collections import defaultdict -import pytest - from .._align import add_x from .._align import align from .._change import Delete @@ -25,7 +23,6 @@ class SequenceAdapter(Adapter): @classmethod def repr(cls, value): - pytest.skip() if len(value) == 1 and cls.trailing_comma: seq = repr(value[0]) + "," else: @@ -34,13 +31,11 @@ def repr(cls, value): @classmethod def map(cls, value, map_function): - pytest.skip() result = [adapter_map(v, map_function) for v in value] return cls.value_type(result) @classmethod def items(cls, value, node): - pytest.skip() if node is None or not isinstance(node, cls.node_type): return [Item(value=v, node=None) for v in value] @@ -49,7 +44,6 @@ def items(cls, value, node): return [Item(value=v, node=n) for v, n in zip(value.value, node.elts)] def assign(self, old_value, old_node, new_value): - pytest.skip() if old_node is not None: if not isinstance( old_node, ast.List if isinstance(old_value, list) else ast.Tuple diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 5edc93cf..16b2ada4 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -35,7 +35,7 @@ def __eq__(self, other): if type(other) is not self._type: return False - other_repr = code_repr(other) + other_repr = value_code_repr(other) return other_repr == self._str_repr or other_repr == repr(self) diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index bfe332ab..960cbe88 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -5,6 +5,8 @@ from abc import abstractmethod from collections import Counter from collections import defaultdict +from pathlib import Path +from pathlib import PurePath from types import BuiltinFunctionType from types import FunctionType from typing import Any @@ -12,6 +14,7 @@ from inline_snapshot._code_repr import value_code_repr from inline_snapshot._unmanaged import is_unmanaged +from inline_snapshot._utils import clone custom_functions = [] @@ -72,7 +75,7 @@ def repr(self): return "" def map(self, f): - return self.value + return f(self.value) class CustomUndefined(Custom): @@ -192,6 +195,7 @@ def repr(self) -> str: class CustomValue(Custom): def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) + value = clone(value) if repr_str is None: self.repr_str = value_code_repr(value) @@ -246,6 +250,15 @@ def type_handler(value, builder: Builder): return builder.Value(value, value.__qualname__) +@customize +def path_handler(value, builder: Builder): + if isinstance(value, Path): + return builder.Call(value, Path, [value.as_posix()]) + + if isinstance(value, PurePath): + return builder.Call(value, PurePath, [value.as_posix()]) + + @customize def dataclass_handler(value, builder: Builder): diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 0f098153..d2200d92 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -7,6 +7,7 @@ from executing import Source +from inline_snapshot._customize import CustomUndefined from inline_snapshot._source_file import SourceFile from inline_snapshot._types import SnapshotRefBase @@ -128,12 +129,12 @@ def create_raw(obj, context: AdapterContext): def _changes(self) -> Iterator[Change]: if ( - self._value._old_value is undefined + isinstance(self._value._old_value, CustomUndefined) if self._expr is None else not self._expr.node.args ): - if self._value._new_value is undefined: + if isinstance(self._value._new_value, CustomUndefined): return new_code = self._value._new_code() diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index d1607653..3048fcc1 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -24,6 +24,7 @@ from inline_snapshot._customize import CustomUndefined from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue +from inline_snapshot._exceptions import UsageError from inline_snapshot._utils import value_to_token from inline_snapshot.syntax_warnings import InlineSnapshotInfo from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning @@ -44,9 +45,6 @@ def reeval(old_value: Custom, value: Custom) -> Custom: result = globals()[function_name](old_value, value) assert isinstance(result, Custom) - if not result == value: - breakpoint() - assert result == value return result @@ -66,6 +64,12 @@ def reeval_CustomUndefined(old_value, value): def reeval_CustomValue(old_value: CustomValue, value: CustomValue): + + if not old_value.eval() == value.eval(): + raise UsageError( + "snapshot value should not change. Use Is(...) for dynamic snapshot parts." + ) + return value @@ -112,8 +116,6 @@ def compare( if isinstance(new_value, CustomUnmanaged): raise UsageError("unmanaged values can not be compared with snapshots") - print("compare", old_value, new_value) - if type(old_value) is not type(new_value) or not isinstance( old_node, new_value.node_type ): @@ -148,7 +150,7 @@ def compare_CustomValue( ): if not old_value.eval() == new_value.eval(): warnings.warn_explicit( - f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", + f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value.repr()}", filename=self.context.file._source.filename, lineno=old_node.lineno, category=InlineSnapshotInfo, @@ -487,7 +489,6 @@ def compare_CustomCall( new_code=value.repr(), new_value=value, ) - print(new_value._function) return CustomCall( _function=( yield from self.compare( diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index e5ad7bbb..2fea9504 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -23,9 +23,9 @@ def __contains__(self, item): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = CustomList([item], [Builder().get_handler(item)]) + self._new_value = CustomList([Builder().get_handler(item)]) else: - if item not in self._new_value.value: + if item not in self._new_value.eval(): self._new_value.value.append(Builder().get_handler(item)) if ignore_old_value() or isinstance(self._old_value, CustomUndefined): @@ -38,7 +38,7 @@ def _new_code(self): return self._file._value_to_code(self._new_value.eval()) def _get_changes(self) -> Iterator[Change]: - assert isinstance(self._old_value, CustomList) + assert isinstance(self._old_value, CustomList), self._old_value assert isinstance(self._new_value, CustomList), self._new_value if self._ast_node is None: diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 42c5f31f..46d9435f 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -1,8 +1,6 @@ import ast from typing import Iterator -import pytest - from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomUndefined @@ -13,7 +11,6 @@ from .._change import DictInsert from .._global_state import state from .._inline_snapshot import UndecidedValue -from .._sentinels import undefined from .generic_value import GenericValue @@ -21,11 +18,8 @@ class DictValue(GenericValue): _current_op = "snapshot[key]" def __getitem__(self, index): - - pytest.skip() - if isinstance(self._new_value, CustomUndefined): - self._new_value = CustomDict({}, {}) + self._new_value = CustomDict({}) index = Builder().get_handler(index) @@ -45,7 +39,7 @@ def __getitem__(self, index): child_node = self._ast_node.values[pos] self._new_value.value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._context + old_value.get(index, CustomUndefined()), child_node, self._context ) return self._new_value.value[index] diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 7fb441ab..67fada04 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -4,6 +4,7 @@ from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomUndefined from inline_snapshot._new_adapter import NewAdapter +from inline_snapshot._utils import map_strings from .._change import Change from .._compare_context import compare_only @@ -17,9 +18,6 @@ class EqValue(GenericValue): def __eq__(self, other): other = Builder().get_handler(other) - print("===") - print(self._old_value) - print(other) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 @@ -42,7 +40,7 @@ def __eq__(self, other): ) def _new_code(self): - return self._new_value.repr() + return self._file._token_to_code(map_strings(self._new_value.repr())) def _get_changes(self) -> Iterator[Change]: return iter(getattr(self, "_changes", [])) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 63b0a8a6..c0d3e4b4 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -1,9 +1,6 @@ import ast -import copy from typing import Iterator -import pytest - from inline_snapshot._customize import Builder from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomUndefined @@ -13,31 +10,12 @@ from .._adapter.adapter import AdapterContext from .._adapter.adapter import get_adapter_type from .._change import Change -from .._code_repr import code_repr from .._exceptions import UsageError from .._global_state import state from .._types import SnapshotBase from .._unmanaged import declare_unmanaged -def clone(obj): - new = copy.deepcopy(obj) - if not obj == new: - raise UsageError( - f"""\ -inline-snapshot uses `copy.deepcopy` to copy objects, -but the copied object is not equal to the original one: - -value = {code_repr(obj)} -copied_value = copy.deepcopy(value) -assert value == copied_value - -Please fix the way your object is copied or your __eq__ implementation. -""" - ) - return new - - def ignore_old_value(): return state().update_flags.fix or state().update_flags.update @@ -142,21 +120,17 @@ def __eq__(self, _other): self._type_error("==") def __le__(self, _other): - pytest.skip() __tracebackhide__ = True self._type_error("<=") def __ge__(self, _other): - pytest.skip() __tracebackhide__ = True self._type_error(">=") def __contains__(self, _other): - pytest.skip() __tracebackhide__ = True self._type_error("in") def __getitem__(self, _item): - pytest.skip() __tracebackhide__ = True self._type_error("snapshot[key]") diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index ccf9738e..ffa2d680 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,7 +1,5 @@ from typing import Iterator -import pytest - from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomUndefined @@ -21,7 +19,6 @@ def cmp(a, b): raise NotImplementedError def _generic_cmp(self, other): - pytest.skip() if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 @@ -32,7 +29,7 @@ def _generic_cmp(self, other): return self._return(self.cmp(self._old_value.eval(), other)) else: if not self.cmp(self._new_value.eval(), other): - self._new_value = Builder.get_handler(other) + self._new_value = Builder().get_handler(other) return self._return(self.cmp(self._visible_value().eval(), other)) @@ -41,7 +38,6 @@ def _new_code(self): return self._file._value_to_code(self._new_value.eval()) def _get_changes(self) -> Iterator[Change]: - pytest.skip() # TODO repr() ... new_token = value_to_token(self._new_value.eval()) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 3759ad6b..98580ad9 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -11,12 +11,10 @@ from inline_snapshot._customize import CustomUndefined from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue +from inline_snapshot._new_adapter import NewAdapter from .._adapter.adapter import AdapterContext -from .._adapter.adapter import get_adapter_type from .._change import Change -from .._change import Replace -from .._utils import value_to_token from .generic_value import GenericValue @@ -55,6 +53,9 @@ def verify_tuple(value: Custom, node: ast.Tuple, eval) -> Custom: def verify_dict(value: Custom, node: ast.Dict, eval) -> Custom: """Verify a CustomDict matches its Dict AST node.""" assert isinstance(value, CustomDict) + if any(key is None for key in node.keys): + return value + verified_items = {} for (key, val), key_node, val_node in zip( value.value.items(), node.keys, node.values @@ -97,9 +98,7 @@ class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): old_value = Builder().get_handler(old_value) - print("before verify", old_value) old_value = verify(old_value, ast_node, context.eval) - print("after verify", old_value) assert isinstance(old_value, Custom) self._old_value = old_value @@ -114,30 +113,40 @@ def _new_code(self): assert False def _get_changes(self) -> Iterator[Change]: + assert isinstance(self._new_value, CustomUndefined) + + new_value = Builder().get_handler(self._old_value.eval()) + + adapter = NewAdapter(self._context) + changes = list(adapter.compare(self._old_value, self._ast_node, new_value)) + + assert all(change.flag == "update" for change in changes) + + return changes + + # def handle(node, obj): + + # adapter = get_adapter_type(obj) + # if adapter is not None and hasattr(adapter, "items"): + # for item in adapter.items(obj, node): + # yield from handle(item.node, item.value) + # return + + # if not isinstance(obj, CustomUnmanaged) and node is not None: + # new_token = value_to_token(obj.eval()) + # if self._file._token_of_node(node) != new_token: + # new_code = self._file._token_to_code(new_token) + + # yield Replace( + # node=self._ast_node, + # file=self._file, + # new_code=new_code, + # flag="update", + # old_value=self._old_value, + # new_value=self._old_value, + # ) - def handle(node, obj): - - adapter = get_adapter_type(obj) - if adapter is not None and hasattr(adapter, "items"): - for item in adapter.items(obj, node): - yield from handle(item.node, item.value) - return - - if not isinstance(obj, CustomUnmanaged) and node is not None: - new_token = value_to_token(obj.eval()) - if self._file._token_of_node(node) != new_token: - new_code = self._file._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - file=self._file, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - - yield from handle(self._ast_node, self._old_value) + # yield from handle(self._ast_node, self._old_value) # functions which determine the type diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 8b1870f7..559c64e4 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -1,11 +1,15 @@ import ast +import copy import io import token import tokenize from collections import namedtuple from pathlib import Path +from inline_snapshot._exceptions import UsageError + from ._code_repr import code_repr +from ._code_repr import value_code_repr def link(text, link=None): @@ -149,7 +153,11 @@ def __eq__(self, other): def value_to_token(value): - input = io.StringIO(code_repr(value)) + return map_strings(code_repr(value)) + + +def map_strings(code_repr): + input = io.StringIO(code_repr) def map_string(tok): """Convert strings with newlines in triple quoted strings.""" @@ -173,3 +181,21 @@ def map_string(tok): for t in tokenize.generate_tokens(input.readline) if t.type not in ignore_tokens ] + + +def clone(obj): + new = copy.deepcopy(obj) + if not obj == new: + raise UsageError( + f"""\ +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +value = {value_code_repr(obj)} +copied_value = copy.deepcopy(value) +assert value == copied_value + +Please fix the way your object is copied or your __eq__ implementation. +""" + ) + return new diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index ebe0cd06..53c90d89 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -445,9 +445,7 @@ def run_pytest( Returns: A new Example instance containing the changed files. """ - import pytest - pytest.skip() self.dump_files() with TemporaryDirectory() as dir: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index b3badfb5..88be9440 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -452,9 +452,7 @@ def test_remove_positional_argument(): Example( """\ from inline_snapshot import snapshot - -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter -from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomCall,customize class L: @@ -466,19 +464,10 @@ def __eq__(self,other): return NotImplemented return other.l==self.l -class LAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value_type): - return issubclass(value_type,L) - - @classmethod - def arguments(cls, value): - return CustomCall(L,*value.l) - - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name,int) - return value.l[pos_or_name] +@customize +def handle(value,builder): + if isinstance(value,L): + return builder.Call(value,L,value.l) def test_L1(): for _ in [1,2]: @@ -498,9 +487,7 @@ def test_L3(): { "tests/test_something.py": """\ from inline_snapshot import snapshot - -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter -from inline_snapshot._customize import CustomCall +from inline_snapshot._customize import CustomCall,customize class L: @@ -512,19 +499,10 @@ def __eq__(self,other): return NotImplemented return other.l==self.l -class LAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value_type): - return issubclass(value_type,L) - - @classmethod - def arguments(cls, value): - return CustomCall(L,*value.l) - - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name,int) - return value.l[pos_or_name] +@customize +def handle(value,builder): + if isinstance(value,L): + return builder.Call(value,L,value.l) def test_L1(): for _ in [1,2]: diff --git a/tests/test_docs.py b/tests/test_docs.py index 142e8b2c..26ea291c 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -269,7 +269,6 @@ def change_block(block): ], ) def test_docs(file): - pytest.skip() file_test(file) diff --git a/tests/test_factory_adapter.py b/tests/test_factory_adapter.py deleted file mode 100644 index 572db777..00000000 --- a/tests/test_factory_adapter.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Tests for the factory adapter handling.""" - -from collections import defaultdict - -from inline_snapshot import snapshot - - -def test_factory_adapter_defaultdict(): - """Test that factory functions in defaultdict are handled correctly.""" - d = defaultdict(list) - d["test"].append(1) - d["other"].append(2) - snapshot(d) == defaultdict(list, {"test": [1], "other": [2]}) - - -def test_factory_adapter_subclass(): - """Test that custom factory functions are handled correctly.""" - - class CustomFactory: - def __call__(self): - return 42 - - d = defaultdict(CustomFactory()) - d["test"] # Access to trigger factory - snapshot(d) == defaultdict(CustomFactory(), {"test": 42}) - - -def test_factory_adapter_lambda(): - """Test that lambda factory functions are handled correctly.""" - d = defaultdict(lambda: "default") - d["test"] # Access to trigger factory - snapshot(d) == defaultdict(lambda: "default", {"test": "default"}) - - -def test_factory_adapter_builtin_types(): - """Test that builtin type factories are handled correctly.""" - for factory in (list, dict, set, str, int, float): - d = defaultdict(factory) - d["test"] # Access to trigger factory - snapshot(d) == defaultdict(factory, {"test": factory()}) diff --git a/tests/test_preserve_values.py b/tests/test_preserve_values.py index 3017e144..bbaacbec 100644 --- a/tests/test_preserve_values.py +++ b/tests/test_preserve_values.py @@ -174,7 +174,7 @@ def test_preserve_case_from_original_mr(check_update): 5, ], }, - "e": ({"f": 6, "g": 7},), + "e": ({"f": 3 + 3, "g": 7},), } ) """ diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 51828c70..2a9cb981 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -510,7 +510,6 @@ def test_assertion_error(project): @pytest.mark.no_rewriting def test_run_without_pytest(pytester): - pytest.skip() # snapshots are deactivated by default pytester.makepyfile( test_file=""" From 5bb4c49c8e5a16463d7ef21095daea3552255d41 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 29 Nov 2025 20:49:23 +0100 Subject: [PATCH 06/72] refactoring --- pyproject.toml | 1 + src/inline_snapshot/_change.py | 4 +- src/inline_snapshot/_customize.py | 2 +- src/inline_snapshot/_new_adapter.py | 76 ++++++++++--------- src/inline_snapshot/_snapshot/dict_value.py | 11 ++- .../_snapshot/undecided_value.py | 7 +- 6 files changed, 53 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aabf921f..09e05fcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,6 +175,7 @@ extra-dependencies = [ "pytest-mock>=3.14.0", "black==25.1.0", "setuptools" + "attrs", ] env-vars.TOP = "{root}" diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index eb5c5487..d510a11c 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -117,7 +117,7 @@ def apply_external_changes(self): @dataclass() class Delete(Change): - node: ast.AST + node: ast.AST | None old_value: Any @@ -134,7 +134,7 @@ class AddArgument(Change): @dataclass() class ListInsert(Change): - node: ast.List + node: ast.List | ast.Tuple position: int new_code: list[str] diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 960cbe88..8406a44f 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -309,7 +309,7 @@ def attrs_handler(value, builder: Builder): default_value = ( field.default - if not isinstance(field.default, attrs.Factory) + if not isinstance(field.default, attrs.Factory) # type: ignore else ( field.default.factory() if not field.default.takes_self diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 3048fcc1..3bbb7a67 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -36,7 +36,7 @@ def reeval(old_value: Custom, value: Custom) -> Custom: return reeval(old_value.value, value) if isinstance(value, CustomDefault): - return CustomDefault(reeval(old_value, value.value)) + return CustomDefault(value=reeval(old_value, value.value)) if type(old_value) is not type(value): return CustomUnmanaged(value.eval()) @@ -75,14 +75,10 @@ def reeval_CustomValue(old_value: CustomValue, value: CustomValue): def reeval_CustomCall(old_value: CustomCall, value: CustomCall): return CustomCall( - _function=reeval(old_value._function, value._function), - _args=[reeval(a, b) for a, b in zip(old_value._args, value._args)], - _kwargs={ - k: reeval(old_value._kwargs[k], value._kwargs[k]) for k in old_value._kwargs - }, - _kwonly={ - k: reeval(old_value._kwonly[k], value._kwonly[k]) for k in old_value._kwonly - }, + reeval(old_value._function, value._function), + [reeval(a, b) for a, b in zip(old_value._args, value._args)], + {k: reeval(old_value._kwargs[k], value._kwargs[k]) for k in old_value._kwargs}, + {k: reeval(old_value._kwonly[k], value._kwonly[k]) for k in old_value._kwonly}, ) @@ -187,14 +183,13 @@ def compare_CustomValue( def compare_CustomSequence( self, old_value: CustomSequence, old_node: ast.AST, new_value: CustomSequence - ) -> Generator[Change, None, CustomList]: + ) -> Generator[Change, None, CustomSequence]: if old_node is not None: - if not isinstance( + assert isinstance( old_node, ast.List if isinstance(old_value.eval(), list) else ast.Tuple - ): - breakpoint() - assert False + ) + assert isinstance(old_node, (ast.List, ast.Tuple)) for e in old_node.elts: if isinstance(e, ast.Starred): @@ -244,7 +239,11 @@ def compare_CustomSequence( for position, code_values in to_insert.items(): yield ListInsert( - "fix", self.context.file._source, old_node, position, *zip(*code_values) + "fix", + self.context.file._source, + old_node, + position, + *zip(*code_values), # type:ignore ) return type(new_value)(result) @@ -254,7 +253,7 @@ def compare_CustomSequence( def compare_CustomDict( self, old_value: CustomDict, old_node: ast.AST, new_value: CustomDict - ) -> Generator[Change, None, CustomDict]: + ) -> Generator[Change, None, Custom]: assert isinstance(old_value, CustomDict) assert isinstance(new_value, CustomDict) @@ -263,13 +262,13 @@ def compare_CustomDict( isinstance(old_node, ast.Dict) and len(old_value.value) == len(old_node.keys) ): - result = yield from self.compare_CustomValue( + result1 = yield from self.compare_CustomValue( old_value, old_node, new_value ) - return result + return result1 - for key, value in zip(old_node.keys, old_node.values): - if key is None: + for key1, value in zip(old_node.keys, old_node.values): + if key1 is None: warnings.warn_explicit( "star-expressions are not supported inside snapshots", filename=self.context.file._source.filename, @@ -278,17 +277,18 @@ def compare_CustomDict( ) return old_value - for value, node in zip(old_value.value.keys(), old_node.keys): + for value2, node in zip(old_value.value.keys(), old_node.keys): - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value.eval() + if node is not None: + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except Exception: + continue + assert node_value == value2.eval() result = {} - for key, node in zip( + for key2, node2 in zip( old_value.value.keys(), ( old_node.values @@ -296,10 +296,10 @@ def compare_CustomDict( else [None] * len(old_value.value) ), ): - if key not in new_value.value: + if key2 not in new_value.value: # delete entries yield Delete( - "fix", self.context.file._source, node, old_value.value[key] + "fix", self.context.file._source, node2, old_value.value[key2] ) to_insert = [] @@ -356,15 +356,17 @@ def compare_CustomDict( to_insert, ) - return CustomDict(result) + return CustomDict(value=result) def compare_CustomCall( self, old_value: CustomCall, old_node: ast.AST, new_value: CustomCall - ) -> Generator[Change, None, CustomCall]: + ) -> Generator[Change, None, Custom]: if old_node is None or not isinstance(old_node, ast.Call): - result = yield from self.compare_CustomValue(old_value, old_node, new_value) - return result + result1 = yield from self.compare_CustomValue( + old_value, old_node, new_value + ) + return result1 # positional arguments for pos_arg in old_node.args: @@ -490,11 +492,11 @@ def compare_CustomCall( new_value=value, ) return CustomCall( - _function=( + ( yield from self.compare( old_value._function, old_node.func, new_value._function ) ), - _args=result_args, - _kwargs=result_kwargs, + result_args, + result_kwargs, ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 46d9435f..bd9dba56 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -17,6 +17,9 @@ class DictValue(GenericValue): _current_op = "snapshot[key]" + _new_value: CustomDict + _old_value: CustomDict + def __getitem__(self, index): if isinstance(self._new_value, CustomUndefined): self._new_value = CustomDict({}) @@ -52,14 +55,14 @@ def _re_eval(self, value, context: AdapterContext): ): for key, s in self._new_value.value.items(): if key in self._old_value.value: - s._re_eval(self._old_value.value[key], context) + s._re_eval(self._old_value.value[key], context) # type:ignore def _new_code(self): return ( "{" + ", ".join( [ - f"{self._file._value_to_code(k)}: {v._new_code()}" + f"{self._file._value_to_code(k)}: {v._new_code()}" # type:ignore for k, v in self._new_value.value.items() if not isinstance(v, UndecidedValue) ] @@ -80,7 +83,7 @@ def _get_changes(self) -> Iterator[Change]: for key, node in zip(self._old_value.value.keys(), values): if key in self._new_value.value: # check values with same keys - yield from self._new_value.value[key]._get_changes() + yield from self._new_value.value[key]._get_changes() # type:ignore else: # delete entries yield Delete("trim", self._file, node, self._old_value.value[key]) @@ -91,7 +94,7 @@ def _get_changes(self) -> Iterator[Change]: new_value_element, UndecidedValue ): # add new values - to_insert.append((key, new_value_element._new_code())) + to_insert.append((key, new_value_element._new_code())) # type:ignore if to_insert: new_code = [(self._file._value_to_code(k.eval()), v) for k, v in to_insert] diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 98580ad9..37f9da02 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -118,11 +118,10 @@ def _get_changes(self) -> Iterator[Change]: new_value = Builder().get_handler(self._old_value.eval()) adapter = NewAdapter(self._context) - changes = list(adapter.compare(self._old_value, self._ast_node, new_value)) - assert all(change.flag == "update" for change in changes) - - return changes + for change in adapter.compare(self._old_value, self._ast_node, new_value): + assert change.flag == "update" + yield change # def handle(node, obj): From bf111b612aa0ef5128c109ec0413ffbf4fa3908b Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 29 Nov 2025 21:10:55 +0100 Subject: [PATCH 07/72] fix: ci --- src/inline_snapshot/_new_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 3bbb7a67..7ee89f6b 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -45,7 +45,7 @@ def reeval(old_value: Custom, value: Custom) -> Custom: result = globals()[function_name](old_value, value) assert isinstance(result, Custom) - assert result == value + # assert result == value,(result,value) return result From 9d87492d879a3650fba2122d18939206436b9418 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 30 Nov 2025 07:46:13 +0100 Subject: [PATCH 08/72] refactor: remove old adapter --- pyproject.toml | 3 + src/inline_snapshot/_adapter/__init__.py | 3 - src/inline_snapshot/_adapter/adapter.py | 110 --------- src/inline_snapshot/_adapter/dict_adapter.py | 161 ------------- .../_adapter/factory_adapter.py | 30 --- .../_adapter/generic_call_adapter.py | 218 ------------------ .../_adapter/sequence_adapter.py | 119 ---------- src/inline_snapshot/_adapter/value_adapter.py | 83 ------- src/inline_snapshot/_adapter_context.py | 26 +++ src/inline_snapshot/_external/_external.py | 2 +- .../_external/_external_location.py | 2 +- src/inline_snapshot/_inline_snapshot.py | 4 +- src/inline_snapshot/_snapshot/dict_value.py | 2 +- .../_snapshot/generic_value.py | 39 +--- .../_snapshot/undecided_value.py | 2 +- 15 files changed, 36 insertions(+), 768 deletions(-) delete mode 100644 src/inline_snapshot/_adapter/__init__.py delete mode 100644 src/inline_snapshot/_adapter/adapter.py delete mode 100644 src/inline_snapshot/_adapter/dict_adapter.py delete mode 100644 src/inline_snapshot/_adapter/factory_adapter.py delete mode 100644 src/inline_snapshot/_adapter/generic_call_adapter.py delete mode 100644 src/inline_snapshot/_adapter/sequence_adapter.py delete mode 100644 src/inline_snapshot/_adapter/value_adapter.py create mode 100644 src/inline_snapshot/_adapter_context.py diff --git a/pyproject.toml b/pyproject.toml index 09e05fcf..46c74a89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,9 @@ serve = "mkdocs serve --livereload" [tool.hatch.envs.default] installer="uv" +[tool.hatch.envs.cov] +python="3.12" + [tool.hatch.envs.cov.scripts] github=[ "- rm htmlcov/*", diff --git a/src/inline_snapshot/_adapter/__init__.py b/src/inline_snapshot/_adapter/__init__.py deleted file mode 100644 index 2f699011..00000000 --- a/src/inline_snapshot/_adapter/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .adapter import get_adapter_type - -__all__ = ("get_adapter_type",) diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py deleted file mode 100644 index 5757bca3..00000000 --- a/src/inline_snapshot/_adapter/adapter.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -import ast -import typing -from dataclasses import dataclass - -from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import CustomDict -from inline_snapshot._customize import CustomList -from inline_snapshot._customize import CustomTuple -from inline_snapshot._customize import CustomUndefined -from inline_snapshot._customize import CustomValue -from inline_snapshot._source_file import SourceFile - - -def get_adapter_type(value): - - if isinstance(value, CustomCall): - from .generic_call_adapter import CallAdapter - - return CallAdapter - - if isinstance(value, CustomList): - from .sequence_adapter import ListAdapter - - return ListAdapter - - if isinstance(value, CustomTuple): - from .sequence_adapter import TupleAdapter - - return TupleAdapter - - if isinstance(value, CustomDict): - from .dict_adapter import DictAdapter - - return DictAdapter - - if isinstance(value, (CustomValue, CustomUndefined)): - from .value_adapter import ValueAdapter - - return ValueAdapter - - raise AssertionError(f"no handler for {type(value)}") - - -class Item(typing.NamedTuple): - value: typing.Any - node: ast.expr - - -@dataclass -class FrameContext: - globals: dict - locals: dict - - -@dataclass -class AdapterContext: - file: SourceFile - frame: FrameContext | None - qualname: str - - def eval(self, node): - assert self.frame is not None - - return eval( - compile(ast.Expression(node), self.file.filename, "eval"), - self.frame.globals, - self.frame.locals, - ) - - -class Adapter: - context: AdapterContext - - def __init__(self, context: AdapterContext): - self.context = context - - def get_adapter(self, old_value, new_value) -> Adapter: - if type(old_value) is not type(new_value): - from .value_adapter import ValueAdapter - - return ValueAdapter(self.context) - - adapter_type = get_adapter_type(old_value) - if adapter_type is not None: - return adapter_type(self.context) - assert False - - def assign(self, old_value, old_node, new_value): - raise NotImplementedError(self) - - def value_assign(self, old_value, old_node, new_value): - from .value_adapter import ValueAdapter - - adapter = ValueAdapter(self.context) - result = yield from adapter.assign(old_value, old_node, new_value) - return result - - @classmethod - def map(cls, value, map_function): - raise NotImplementedError(cls) - - @classmethod - def repr(cls, value): - raise NotImplementedError(cls) - - -def adapter_map(value, map_function): - return get_adapter_type(value).map(value, map_function) diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py deleted file mode 100644 index 7740b767..00000000 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -import ast -import warnings - -from inline_snapshot._customize import CustomDict - -from .._change import Delete -from .._change import DictInsert -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item -from .adapter import adapter_map - - -class DictAdapter(Adapter): - - @classmethod - def repr(cls, value): - result = ( - "{" - + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) - + "}" - ) - - if type(value) is not dict: - result = f"{repr(type(value))}({result})" - - return result - - @classmethod - def map(cls, value, map_function): - return {k: adapter_map(v, map_function) for k, v in value.items()} - - @classmethod - def items(cls, value, node): - assert isinstance(value, CustomDict) - value = value.value - if node is None or not isinstance(node, ast.Dict): - return [Item(value=value, node=None) for value in value.values()] - - result = [] - - for value_key, node_key, node_value in zip( - value.keys(), node.keys, node.values - ): - try: - # this is just a sanity check, dicts should be ordered - node_key = ast.literal_eval(node_key) - except Exception: - pass - else: - assert node_key == value_key.eval(), f"{node_key!r} != {value_key!r}" - - result.append(Item(value=value[value_key], node=node_value)) - - return result - - def assign(self, old_value, old_node, new_value): - assert isinstance(old_value, CustomDict) - assert isinstance(new_value, CustomDict) - - if old_node is not None: - if not ( - isinstance(old_node, ast.Dict) - and len(old_value.value) == len(old_node.keys) - ): - result = yield from self.value_assign( - old_value.value, old_node, new_value - ) - return result - - for key, value in zip(old_node.keys, old_node.values): - if key is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - for value, node in zip(old_value.value.keys(), old_node.keys): - - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value.eval() - - result = {} - for key, node in zip( - old_value.value.keys(), - ( - old_node.values - if old_node is not None - else [None] * len(old_value.value) - ), - ): - if key not in new_value: - # delete entries - yield Delete( - "fix", self.context.file._source, node, old_value.value[key] - ) - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_value.items(): - if key not in old_value.value: - # add new values - to_insert.append((key, new_value_element)) - result[key] = new_value_element - else: - if isinstance(old_node, ast.Dict): - node = old_node.values[list(old_value.value.keys()).index(key)] - else: - node = None - # check values with same keys - result[key] = yield from self.get_adapter( - old_value.value[key], new_value.value[key] - ).assign(old_value[key], node, new_value[key]) - - if to_insert: - new_code = [ - ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), - ) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self.context.file._source, - old_node, - insert_pos, - new_code, - to_insert, - ) - to_insert = [] - - insert_pos += 1 - - if to_insert: - new_code = [ - ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), - ) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self.context.file._source, - old_node, - len(old_value.value), - new_code, - to_insert, - ) - - return CustomDict(result) diff --git a/src/inline_snapshot/_adapter/factory_adapter.py b/src/inline_snapshot/_adapter/factory_adapter.py deleted file mode 100644 index dd182c60..00000000 --- a/src/inline_snapshot/_adapter/factory_adapter.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Factory function adapter for handling factory functions like list() cleanly.""" - -from .adapter import Adapter - - -class FactoryAdapter(Adapter): - """Adapter for factory functions used in defaultdict.""" - - @classmethod - def check_type(cls, value_type): - # Check if value is a factory function (type/class or callable) - if isinstance(value_type, type): - return True - return callable(value_type) - - @classmethod - def repr(cls, value): - # Return clean name for builtin types - value_str = repr(value) - if value_str.startswith(" wrapper - return value_str - - @classmethod - def map(cls, value, map_function): - return value - - def assign(self, old_value, old_node, new_value): - # Preserve factory function identity - return old_value diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py deleted file mode 100644 index 173465af..00000000 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ /dev/null @@ -1,218 +0,0 @@ -from __future__ import annotations - -import ast -import warnings -from typing import Any - -from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import CustomDefault -from inline_snapshot._customize import unwrap_default - -from .._change import CallArg -from .._change import Delete -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item - - -def get_adapter_for_type(value_type): - assert False, "unreachable" - assert isinstance(value_type, CustomCall) - return CallAdapter - - -class CallAdapter(Adapter): - - @classmethod - def arguments(cls, value) -> CustomCall: - return value - - @classmethod - def argument(cls, value, pos_or_name) -> Any: - return cls.arguments(value).argument(pos_or_name) - - @classmethod - def repr(cls, value): - - call = cls.arguments(value) - - arguments = [repr(value) for value in call.args] + [ - f"{key}={repr(value)}" - for key, value in call.kwargs.items() - if not isinstance(value, CustomDefault) - ] - - return f"{repr(type(value))}({', '.join(arguments)})" - - @classmethod - def map(cls, value, map_function): - return cls.arguments(value).map(map_function) - - @classmethod - def items(cls, value, node): - - args = cls.arguments(value) - new_args = args.args - new_kwargs = args.kwargs - - if node is not None: - assert isinstance(node, ast.Call) - assert all(kw.arg for kw in node.keywords) - kw_arg_node = {kw.arg: kw.value for kw in node.keywords if kw.arg}.get - - def pos_arg_node(pos): - return node.args[pos] - - else: - - def kw_arg_node(_): - return None - - def pos_arg_node(_): - return None - - return [ - Item(value=unwrap_default(arg), node=pos_arg_node(i)) - for i, arg in enumerate(new_args) - ] + [ - Item(value=unwrap_default(kw), node=kw_arg_node(name)) - for name, kw in new_kwargs.items() - ] - - def assign(self, old_value, old_node, new_value): - if old_node is None or not isinstance(old_node, ast.Call): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - call_type = self.context.eval(old_node.func) - - if not (isinstance(call_type, type) and self.check_type(call_type)): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - # positional arguments - for pos_arg in old_node.args: - if isinstance(pos_arg, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=pos_arg.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - # keyword arguments - for kw in old_node.keywords: - if kw.arg is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=kw.value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - call = self.arguments(new_value) - new_args = call.args - new_kwargs = call.kwargs - - # positional arguments - - result_args = [] - - for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): - old_value_element = self.argument(old_value, i) - result = yield from self.get_adapter( - old_value_element, unwrap_default(new_value_element) - ).assign(old_value_element, node, unwrap_default(new_value_element)) - result_args.append(result) - - if len(old_node.args) > len(new_args): - for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: - yield Delete( - "fix", - self.context.file._source, - node, - self.argument(old_value, arg_pos), - ) - - if len(old_node.args) < len(new_args): - for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=None, - new_code=self.context.file._value_to_code(unwrap_default(value)), - new_value=value, - ) - - # keyword arguments - result_kwargs = {} - for kw in old_node.keywords: - if kw.arg not in new_kwargs or isinstance( - new_kwargs[kw.arg], CustomDefault - ): - # delete entries - yield Delete( - ( - "update" - if self.argument(old_value, kw.arg) - == self.argument(new_value, kw.arg) - else "fix" - ), - self.context.file._source, - kw.value, - self.argument(old_value, kw.arg), - ) - - old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_kwargs.items(): - if isinstance(new_value_element, CustomDefault): - continue - if key not in old_node_kwargs: - # add new values - to_insert.append((key, new_value_element)) - result_kwargs[key] = new_value_element - else: - node = old_node_kwargs[key] - - # check values with same keys - old_value_element = self.argument(old_value, key) - result_kwargs[key] = yield from self.get_adapter( - old_value_element, new_value_element - ).assign(old_value_element, node, new_value_element) - - if to_insert: - for key, value in to_insert: - - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=key, - new_code=self.context.file._value_to_code(value), - new_value=value, - ) - to_insert = [] - - insert_pos += 1 - - if to_insert: - for key, value in to_insert: - - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=key, - new_code=self.context.file._value_to_code(value), - new_value=value, - ) - return type(old_value)(*result_args, **result_kwargs) diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py deleted file mode 100644 index f917d246..00000000 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -import ast -import warnings -from collections import defaultdict - -from .._align import add_x -from .._align import align -from .._change import Delete -from .._change import ListInsert -from .._compare_context import compare_context -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item -from .adapter import adapter_map - - -class SequenceAdapter(Adapter): - node_type: type - value_type: type - braces: str - trailing_comma: bool - - @classmethod - def repr(cls, value): - if len(value) == 1 and cls.trailing_comma: - seq = repr(value[0]) + "," - else: - seq = ", ".join(map(repr, value)) - return cls.braces[0] + seq + cls.braces[1] - - @classmethod - def map(cls, value, map_function): - result = [adapter_map(v, map_function) for v in value] - return cls.value_type(result) - - @classmethod - def items(cls, value, node): - if node is None or not isinstance(node, cls.node_type): - return [Item(value=v, node=None) for v in value] - - assert len(value.value) == len(node.elts) - - return [Item(value=v, node=n) for v, n in zip(value.value, node.elts)] - - def assign(self, old_value, old_node, new_value): - if old_node is not None: - if not isinstance( - old_node, ast.List if isinstance(old_value, list) else ast.Tuple - ): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - with compare_context(): - diff = add_x(align(old_value, new_value)) - old = zip( - old_value, - old_node.elts if old_node is not None else [None] * len(old_value), - ) - new = iter(new_value) - old_position = 0 - to_insert = defaultdict(list) - result = [] - for c in diff: - if c in "mx": - old_value_element, old_node_element = next(old) - new_value_element = next(new) - v = yield from self.get_adapter( - old_value_element, new_value_element - ).assign(old_value_element, old_node_element, new_value_element) - result.append(v) - old_position += 1 - elif c == "i": - new_value_element = next(new) - new_code = self.context.file._value_to_code(new_value_element) - result.append(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) - elif c == "d": - old_value_element, old_node_element = next(old) - yield Delete( - "fix", - self.context.file._source, - old_node_element, - old_value_element, - ) - old_position += 1 - else: - assert False - - for position, code_values in to_insert.items(): - yield ListInsert( - "fix", self.context.file._source, old_node, position, *zip(*code_values) - ) - - return self.value_type(result) - - -class ListAdapter(SequenceAdapter): - node_type = ast.List - value_type = list - braces = "[]" - trailing_comma = False - - -class TupleAdapter(SequenceAdapter): - node_type = ast.Tuple - value_type = tuple - braces = "()" - trailing_comma = True diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py deleted file mode 100644 index e443af4c..00000000 --- a/src/inline_snapshot/_adapter/value_adapter.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations - -import ast -import warnings - -from inline_snapshot._customize import Custom -from inline_snapshot._customize import CustomUnmanaged -from inline_snapshot._customize import CustomValue - -from .._change import Replace -from .._code_repr import value_code_repr -from .._sentinels import undefined -from .._unmanaged import update_allowed -from .._utils import value_to_token -from ..syntax_warnings import InlineSnapshotInfo -from .adapter import Adapter - - -class ValueAdapter(Adapter): - - @classmethod - def repr(cls, value): - return value_code_repr(value) - - @classmethod - def map(cls, value, map_function): - return map_function(value) - - def assign(self, old_value, old_node, new_value): - # generic fallback - assert isinstance(old_value, Custom) - assert isinstance(new_value, Custom) - - # because IsStr() != IsStr() - if isinstance(old_value, CustomUnmanaged): - return old_value - - if old_node is None: - new_token = [] - else: - new_token = value_to_token(new_value.eval()) - - if ( - isinstance(old_node, ast.JoinedStr) - and isinstance(new_value, CustomValue) - and isinstance(new_value.value, str) - ): - if not old_value.eval() == new_value.eval(): - warnings.warn_explicit( - f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", - filename=self.context.file._source.filename, - lineno=old_node.lineno, - category=InlineSnapshotInfo, - ) - return old_value - - if not old_value == new_value: - if old_value is undefined: - flag = "create" - else: - flag = "fix" - elif ( - old_node is not None - and update_allowed(old_value) - and self.context.file._token_of_node(old_node) != new_token - ): - flag = "update" - else: - # equal and equal repr - return old_value - - new_code = self.context.file._token_to_code(new_token) - - yield Replace( - node=old_node, - file=self.context.file._source, - new_code=new_code, - flag=flag, - old_value=old_value, - new_value=new_value, - ) - - return new_value diff --git a/src/inline_snapshot/_adapter_context.py b/src/inline_snapshot/_adapter_context.py new file mode 100644 index 00000000..9de8cc55 --- /dev/null +++ b/src/inline_snapshot/_adapter_context.py @@ -0,0 +1,26 @@ +import ast +from dataclasses import dataclass + +from inline_snapshot._source_file import SourceFile + + +@dataclass +class FrameContext: + globals: dict + locals: dict + + +@dataclass +class AdapterContext: + file: SourceFile + frame: FrameContext | None + qualname: str + + def eval(self, node): + assert self.frame is not None + + return eval( + compile(ast.Expression(node), self.file.filename, "eval"), + self.frame.globals, + self.frame.locals, + ) diff --git a/src/inline_snapshot/_external/_external.py b/src/inline_snapshot/_external/_external.py index 876809fc..b6800db9 100644 --- a/src/inline_snapshot/_external/_external.py +++ b/src/inline_snapshot/_external/_external.py @@ -3,7 +3,7 @@ import ast from pathlib import Path -from inline_snapshot._adapter.adapter import AdapterContext +from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import CallArg from inline_snapshot._change import Replace from inline_snapshot._exceptions import UsageError diff --git a/src/inline_snapshot/_external/_external_location.py b/src/inline_snapshot/_external/_external_location.py index 28f6a3ea..f5adddad 100644 --- a/src/inline_snapshot/_external/_external_location.py +++ b/src/inline_snapshot/_external/_external_location.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Generator -from inline_snapshot._adapter.adapter import AdapterContext +from inline_snapshot._adapter_context import AdapterContext class Location: diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index d2200d92..6f23dd2c 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -7,12 +7,12 @@ from executing import Source +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._adapter_context import FrameContext from inline_snapshot._customize import CustomUndefined from inline_snapshot._source_file import SourceFile from inline_snapshot._types import SnapshotRefBase -from ._adapter.adapter import AdapterContext -from ._adapter.adapter import FrameContext from ._change import CallArg from ._change import Change from ._global_state import state diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index bd9dba56..24a196cc 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -5,7 +5,7 @@ from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomUndefined -from .._adapter.adapter import AdapterContext +from .._adapter_context import AdapterContext from .._change import Change from .._change import Delete from .._change import DictInsert diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index c0d3e4b4..f1cece1f 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -1,16 +1,13 @@ import ast from typing import Iterator +from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._customize import Builder from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomUndefined -from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._new_adapter import reeval -from .._adapter.adapter import AdapterContext -from .._adapter.adapter import get_adapter_type from .._change import Change -from .._exceptions import UsageError from .._global_state import state from .._types import SnapshotBase from .._unmanaged import declare_unmanaged @@ -47,45 +44,11 @@ def _return(self, result, new_result=True): def _file(self): return self._context.file - def get_adapter(self, value): - return get_adapter_type(value)(self._context) - def _re_eval(self, value, context: AdapterContext): self._old_value = reeval(self._old_value, Builder().get_handler(value)) return - self._context = context - - def re_eval(old_value, node, value): - if isinstance(old_value, CustomUnmanaged): - old_value.value = value - return - - assert type(old_value) is type(value) - - adapter = self.get_adapter(old_value) - if adapter is not None and hasattr(adapter, "items"): - old_items = adapter.items(old_value, node) - new_items = adapter.items(value, node) - assert len(old_items) == len(new_items) - - for old_item, new_item in zip(old_items, new_items): - re_eval(old_item.value, old_item.node, new_item.value) - - else: - if not isinstance(old_value, CustomUnmanaged): - if not old_value.eval() == value.eval(): - raise UsageError( - "snapshot value should not change. Use Is(...) for dynamic snapshot parts." - ) - else: - assert ( - False - ), f"old_value ({type(old_value)}) should be converted to Unmanaged" - - re_eval(self._old_value, self._ast_node, Builder().get_handler(value)) - def _ignore_old(self): return ( state().update_flags.fix diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 37f9da02..52efbf96 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -13,7 +13,7 @@ from inline_snapshot._customize import CustomValue from inline_snapshot._new_adapter import NewAdapter -from .._adapter.adapter import AdapterContext +from .._adapter_context import AdapterContext from .._change import Change from .generic_value import GenericValue From a2d0f90cff059e2c6a2796a32a1f1502e223c255 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 1 Dec 2025 17:25:00 +0100 Subject: [PATCH 09/72] refactor: coverage --- src/inline_snapshot/_code_repr.py | 48 ----------------------- src/inline_snapshot/_customize.py | 33 ++++++++++++++++ src/inline_snapshot/_new_adapter.py | 12 +----- src/inline_snapshot/_snapshot/eq_value.py | 8 ++-- tests/adapter/test_general.py | 22 +++++++++++ 5 files changed, 60 insertions(+), 63 deletions(-) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 16b2ada4..28d5cd62 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -2,8 +2,6 @@ from enum import Enum from enum import Flag from functools import singledispatch -from types import BuiltinFunctionType -from types import FunctionType from unittest import mock real_repr = repr @@ -117,49 +115,3 @@ def _(value: Enum): def _(value: Flag): name = type(value).__qualname__ return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) - - -def sort_set_values(set_values): - is_sorted = False - try: - set_values = sorted(set_values) - is_sorted = True - except TypeError: - pass - - set_values = list(map(repr, set_values)) - if not is_sorted: - set_values = sorted(set_values) - - return set_values - - -@customize_repr -def _(value: set): - if len(value) == 0: - return "set()" - - return "{" + ", ".join(sort_set_values(value)) + "}" - - -@customize_repr -def _(value: frozenset): - if len(value) == 0: - return "frozenset()" - - return "frozenset({" + ", ".join(sort_set_values(value)) + "})" - - -@customize_repr -def _(value: type): - return value.__qualname__ - - -@customize_repr -def _(value: FunctionType): - return value.__qualname__ - - -@customize_repr -def _(value: BuiltinFunctionType): - return value.__name__ diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 8406a44f..66efc13b 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -259,6 +259,39 @@ def path_handler(value, builder: Builder): return builder.Call(value, PurePath, [value.as_posix()]) +def sort_set_values(set_values): + is_sorted = False + try: + set_values = sorted(set_values) + is_sorted = True + except TypeError: + pass + + set_values = list(map(repr, set_values)) + if not is_sorted: + set_values = sorted(set_values) + + return set_values + + +@customize +def set_handler(value, builder: Builder): + if isinstance(value, set): + if len(value) == 0: + return builder.Value(value, "set()") + else: + return builder.Value(value, "{" + ", ".join(sort_set_values(value)) + "}") + + +@customize +def frozenset_handler(value, builder: Builder): + if isinstance(value, frozenset): + if len(value) == 0: + return builder.Value(value, "frozenset()") + else: + return builder.Call(value, frozenset, [set(value)]) + + @customize def dataclass_handler(value, builder: Builder): diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 7ee89f6b..af86bf08 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -130,10 +130,6 @@ def compare_CustomValue( assert isinstance(old_value, Custom) assert isinstance(new_value, Custom) - # because IsStr() != IsStr() - if isinstance(old_value, CustomUnmanaged): - return old_value - if old_node is None: new_token = [] else: @@ -359,15 +355,9 @@ def compare_CustomDict( return CustomDict(value=result) def compare_CustomCall( - self, old_value: CustomCall, old_node: ast.AST, new_value: CustomCall + self, old_value: CustomCall, old_node: ast.Call, new_value: CustomCall ) -> Generator[Change, None, Custom]: - if old_node is None or not isinstance(old_node, ast.Call): - result1 = yield from self.compare_CustomValue( - old_value, old_node, new_value - ) - return result1 - # positional arguments for pos_arg in old_node.args: if isinstance(pos_arg, ast.Starred): diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 67fada04..37b81078 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -17,7 +17,7 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - other = Builder().get_handler(other) + custom_other = Builder().get_handler(other) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 @@ -26,7 +26,7 @@ def __eq__(self, other): self._changes = [] adapter = NewAdapter(self._context) - it = iter(adapter.compare(self._old_value, self._ast_node, other)) + it = iter(adapter.compare(self._old_value, self._ast_node, custom_other)) while True: try: self._changes.append(next(it)) @@ -35,8 +35,8 @@ def __eq__(self, other): break return self._return( - self._old_value.eval() == other.eval(), - self._new_value.eval() == other.eval(), + self._old_value.eval() == custom_other.eval(), + self._new_value.eval() == custom_other.eval(), ) def _new_code(self): diff --git a/tests/adapter/test_general.py b/tests/adapter/test_general.py index 3afad504..71fb5684 100644 --- a/tests/adapter/test_general.py +++ b/tests/adapter/test_general.py @@ -45,3 +45,25 @@ def test_thing(): assert (i,) == (Is(i),) """ ).run_pytest(["--inline-snapshot=short-report"], report=snapshot("")) + + +def test_usageerror_unmanaged(): + + Example( + """\ +from inline_snapshot import snapshot,Is + + +def test_thing(): + assert [Is(5)] == snapshot([6]) +""" + ).run_inline( + ["--inline-snapshot=fix"], + report=snapshot(""), + raises=snapshot( + """\ +UsageError: +unmanaged values can not be compared with snapshots\ +""" + ), + ) From 8363e8e17684fc8f87b4f5651478bdc4d7186131 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 1 Dec 2025 20:16:45 +0100 Subject: [PATCH 10/72] refactor: coverage --- .../20251201_200505_15r10nk-git_customize.md | 3 + src/inline_snapshot/_customize.py | 18 ++---- tests/test_code_repr.py | 64 ++++++++++++++++++- 3 files changed, 71 insertions(+), 14 deletions(-) create mode 100644 changelog.d/20251201_200505_15r10nk-git_customize.md diff --git a/changelog.d/20251201_200505_15r10nk-git_customize.md b/changelog.d/20251201_200505_15r10nk-git_customize.md new file mode 100644 index 00000000..b6cc6f8a --- /dev/null +++ b/changelog.d/20251201_200505_15r10nk-git_customize.md @@ -0,0 +1,3 @@ +### Added + +- `pathlib.Path/PurePath` values are now never stored as `Posix/WindowsPath` or their Pure variants, which improves the writing of platform independent tests. diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 66efc13b..20258ba9 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -39,10 +39,8 @@ def __hash__(self): return hash(self.eval()) def __eq__(self, other): - if isinstance(other, Custom): - return self.eval() == other.eval() - - return NotImplemented + assert isinstance(other, Custom) + return self.eval() == other.eval() @abstractmethod def map(self, f): @@ -61,7 +59,9 @@ class CustomDefault(Custom): value: Custom = field(compare=False) def repr(self): - return self.value.repr() + assert ( + False + ), "this should never be called because default values are never converted into code" def map(self, f): return self.value.map(f) @@ -125,14 +125,6 @@ def all_pos_args(self): def kwargs(self): return {**self._kwargs, **self._kwonly} - def kwonly(self, **kwonly): - assert not self._kwonly, "you should not call kwonly twice" - assert ( - not kwonly.keys() & self._kwargs.keys() - ), "same keys in kwargs and kwonly arguments" - self._kwonly = kwonly - return self - def argument(self, pos_or_str): if isinstance(pos_or_str, int): return unwrap_default(self.all_pos_args[pos_or_str]) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 3657f3b5..83c8c6e8 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -49,6 +49,56 @@ class color(Enum): ) +def test_path(): + + Example( + """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot + +folder="a" + +def test_a(): + assert Path(folder,"b.txt") == snapshot() + assert PurePath(folder,"b.txt") == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot + +folder="a" + +def test_a(): + assert Path(folder,"b.txt") == snapshot(Path("a/b.txt")) + assert PurePath(folder,"b.txt") == snapshot(PurePath("a/b.txt")) +""" + } + ), + ).replace( + '"a"', '"c"' + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot + +folder="c" + +def test_a(): + assert Path(folder,"b.txt") == snapshot(Path("c/b.txt")) + assert PurePath(folder,"b.txt") == snapshot(PurePath("c/b.txt")) +""" + } + ), + ) + + def test_snapshot_generates_hasrepr(): Example( @@ -334,7 +384,7 @@ def test_datatypes_explicit(): assert code_repr(default_dict) == snapshot("defaultdict(list, {5: [2], 3: [1]})") -def test_tuple(): +def test_fake_tuple1(): class FakeTuple(tuple): def __init__(self): @@ -346,6 +396,18 @@ def __repr__(self): assert code_repr(FakeTuple()) == snapshot("FakeTuple()") +def test_fake_tuple2(): + + class FakeTuple(tuple): + def __init__(self): + self._fields = 1 + + def __repr__(self): + return "FakeTuple()" + + assert code_repr(FakeTuple()) == snapshot("FakeTuple()") + + def test_invalid_repr(check_update): assert ( check_update( From 341ed47e4f14daafb35cf61709b5077be2f3e87a Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 1 Dec 2025 21:14:45 +0100 Subject: [PATCH 11/72] refactor: coverage --- src/inline_snapshot/_customize.py | 5 ++-- src/inline_snapshot/_new_adapter.py | 38 ++++++++++++----------------- src/inline_snapshot/_utils.py | 4 ++- tests/test_code_repr.py | 6 ++--- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 20258ba9..4beac028 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -59,9 +59,8 @@ class CustomDefault(Custom): value: Custom = field(compare=False) def repr(self): - assert ( - False - ), "this should never be called because default values are never converted into code" + # this should never be called because default values are never converted into code + assert False def map(self, f): return self.value.map(f) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index af86bf08..977ca4d4 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -20,7 +20,6 @@ from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomSequence -from inline_snapshot._customize import CustomTuple from inline_snapshot._customize import CustomUndefined from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue @@ -54,6 +53,9 @@ def reeval_CustomList(old_value: CustomList, value: CustomList): return CustomList([reeval(a, b) for a, b in zip(old_value.value, value.value)]) +reeval_CustomTuple = reeval_CustomList + + def reeval_CustomUnmanaged(old_value: CustomUnmanaged, value: CustomUnmanaged): old_value.value = value.value return old_value @@ -82,11 +84,6 @@ def reeval_CustomCall(old_value: CustomCall, value: CustomCall): ) -def reeval_CustomTuple(old_value, value): - assert len(old_value.value) == len(value.value) - return CustomTuple([reeval(a, b) for a, b in zip(old_value.value, value.value)]) - - def reeval_CustomDict(old_value, value): assert len(old_value.value) == len(value.value) return CustomDict( @@ -196,6 +193,8 @@ def compare_CustomSequence( category=InlineSnapshotSyntaxWarning, ) return old_value + else: + pass # pragma: no cover with compare_context(): diff = add_x(align(old_value.value, new_value.value)) @@ -248,20 +247,12 @@ def compare_CustomSequence( compare_CustomList = compare_CustomSequence def compare_CustomDict( - self, old_value: CustomDict, old_node: ast.AST, new_value: CustomDict + self, old_value: CustomDict, old_node: ast.Dict, new_value: CustomDict ) -> Generator[Change, None, Custom]: assert isinstance(old_value, CustomDict) assert isinstance(new_value, CustomDict) if old_node is not None: - if not ( - isinstance(old_node, ast.Dict) - and len(old_value.value) == len(old_node.keys) - ): - result1 = yield from self.compare_CustomValue( - old_value, old_node, new_value - ) - return result1 for key1, value in zip(old_node.keys, old_node.values): if key1 is None: @@ -274,14 +265,15 @@ def compare_CustomDict( return old_value for value2, node in zip(old_value.value.keys(), old_node.keys): - - if node is not None: - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except Exception: - continue - assert node_value == value2.eval() + assert node is not None + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except Exception: + continue + assert node_value == value2.eval() + else: + pass # pragma: no cover result = {} for key2, node2 in zip( diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 559c64e4..cf680eb7 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -143,7 +143,9 @@ def __eq__(self, other): for s in (self.string, other.string) for suffix in ("f", "rf", "Rf", "F", "rF", "RF") ): - return False + # I don't know why this is not covered/(maybe needed) with the new customize algo + # but I think it is better to handle it as 'no cover' for now + return False # pragma: no cover return ast.literal_eval(self.string) == ast.literal_eval( other.string diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 83c8c6e8..20f1c80f 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -387,8 +387,7 @@ def test_datatypes_explicit(): def test_fake_tuple1(): class FakeTuple(tuple): - def __init__(self): - self._fields = 5 + _fields = 5 def __repr__(self): return "FakeTuple()" @@ -399,8 +398,7 @@ def __repr__(self): def test_fake_tuple2(): class FakeTuple(tuple): - def __init__(self): - self._fields = 1 + _fields = (1, 2) def __repr__(self): return "FakeTuple()" From 1bb67471373fb3fb4627b58d90c77bbfd995a045 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 4 Dec 2025 08:41:47 +0100 Subject: [PATCH 12/72] refactor: coverage --- src/inline_snapshot/_new_adapter.py | 131 +++++++++++++++++----------- 1 file changed, 80 insertions(+), 51 deletions(-) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 977ca4d4..6ec320b4 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -4,6 +4,7 @@ import warnings from collections import defaultdict from typing import Generator +from typing import Sequence from inline_snapshot._align import add_x from inline_snapshot._align import align @@ -109,23 +110,34 @@ def compare( if isinstance(new_value, CustomUnmanaged): raise UsageError("unmanaged values can not be compared with snapshots") - if type(old_value) is not type(new_value) or not isinstance( - old_node, new_value.node_type + if ( + type(old_value) is type(new_value) + and ( + isinstance(old_node, new_value.node_type) + if old_node is not None + else True + ) + and ( + isinstance(old_value, (CustomCall, CustomSequence)) + if old_node is None + else True + ) ): + function_name = f"compare_{type(old_value).__name__}" + result = yield from getattr(self, function_name)( + old_value, old_node, new_value + ) + else: result = yield from self.compare_CustomValue(old_value, old_node, new_value) - return result - - function_name = f"compare_{type(old_value).__name__}" - result = yield from getattr(self, function_name)(old_value, old_node, new_value) - return result def compare_CustomValue( - self, old_value: Custom, old_node: ast.AST, new_value: Custom + self, old_value: Custom, old_node: ast.expr, new_value: Custom ) -> Generator[Change, None, Custom]: assert isinstance(old_value, Custom) assert isinstance(new_value, Custom) + assert isinstance(old_node, (ast.expr, type(None))), old_node if old_node is None: new_token = [] @@ -200,7 +212,7 @@ def compare_CustomSequence( diff = add_x(align(old_value.value, new_value.value)) old = zip( old_value.value, - old_node.elts if old_node is not None else [None] * len(old_value), + old_node.elts if old_node is not None else [None] * len(old_value.value), ) new = iter(new_value.value) old_position = 0 @@ -301,6 +313,7 @@ def compare_CustomDict( if isinstance(old_node, ast.Dict): node = old_node.values[list(old_value.value.keys()).index(key)] else: + assert False node = None # check values with same keys result[key] = yield from self.compare( @@ -350,27 +363,28 @@ def compare_CustomCall( self, old_value: CustomCall, old_node: ast.Call, new_value: CustomCall ) -> Generator[Change, None, Custom]: - # positional arguments - for pos_arg in old_node.args: - if isinstance(pos_arg, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=pos_arg.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value + if old_node is not None: + # positional arguments + for pos_arg in old_node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value - # keyword arguments - for kw in old_node.keywords: - if kw.arg is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=kw.value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value + # keyword arguments + for kw in old_node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.file._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value call = new_value new_args = call.args @@ -380,22 +394,31 @@ def compare_CustomCall( result_args = [] - for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): + old_node_args: Sequence[ast.expr | None] + if old_node: + old_node_args = old_node.args + else: + old_node_args = [None] * len(new_args) + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node_args)): old_value_element = old_value.argument(i) result = yield from self.compare(old_value_element, node, new_value_element) result_args.append(result) - if len(old_node.args) > len(new_args): - for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: - yield Delete( - "fix", - self.context.file._source, - node, - old_value.argument(arg_pos), - ) + old_args_len = len(old_node.args if old_node else old_value.args) + + if old_node is not None: + if old_args_len > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + yield Delete( + "fix", + self.context.file._source, + node, + old_value.argument(arg_pos), + ) - if len(old_node.args) < len(new_args): - for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: + if old_args_len < len(new_args): + for insert_pos, value in list(enumerate(new_args))[old_args_len:]: yield CallArg( flag="fix", file=self.context.file._source, @@ -408,35 +431,38 @@ def compare_CustomCall( # keyword arguments result_kwargs = {} - for kw in old_node.keywords: - if kw.arg not in new_kwargs or isinstance( - new_kwargs[kw.arg], CustomDefault + if old_node is None: + old_keywords = {key: None for key in old_value._kwargs.keys()} + else: + old_keywords = {kw.arg: kw.value for kw in old_node.keywords} + + for kw_arg, kw_value in old_keywords.items(): + if kw_arg not in new_kwargs or isinstance( + new_kwargs[kw_arg], CustomDefault ): # delete entries yield Delete( ( "update" - if old_value.argument(kw.arg) == new_value.argument(kw.arg) + if old_value.argument(kw_arg) == new_value.argument(kw_arg) else "fix" ), self.context.file._source, - kw.value, - old_value.argument(kw.arg), + kw_value, + old_value.argument(kw_arg), ) - old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} - to_insert = [] insert_pos = 0 for key, new_value_element in new_kwargs.items(): if isinstance(new_value_element, CustomDefault): continue - if key not in old_node_kwargs: + if key not in old_keywords: # add new values to_insert.append((key, new_value_element)) result_kwargs[key] = new_value_element else: - node = old_node_kwargs[key] + node = old_keywords[key] # check values with same keys old_value_element = old_value.argument(key) @@ -473,10 +499,13 @@ def compare_CustomCall( new_code=value.repr(), new_value=value, ) + return CustomCall( ( yield from self.compare( - old_value._function, old_node.func, new_value._function + old_value._function, + old_node.func if old_node else None, + new_value._function, ) ), result_args, From c39758b1678456c100e75ae1899c7a87c6476887 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 12 Dec 2025 17:47:33 +0100 Subject: [PATCH 13/72] fix: fixed some issues --- docs/customize.md | 79 ++++++ mkdocs.yml | 1 + src/inline_snapshot/__init__.py | 8 + src/inline_snapshot/_code_repr.py | 2 +- src/inline_snapshot/_customize.py | 235 ++++++++++++++---- .../_snapshot/collection_value.py | 4 +- src/inline_snapshot/_snapshot/dict_value.py | 2 +- src/inline_snapshot/_snapshot/eq_value.py | 6 +- .../_snapshot/generic_value.py | 2 +- .../_snapshot/min_max_value.py | 4 +- .../_snapshot/undecided_value.py | 4 +- tests/adapter/test_dataclass.py | 4 +- 12 files changed, 282 insertions(+), 69 deletions(-) create mode 100644 docs/customize.md diff --git a/docs/customize.md b/docs/customize.md new file mode 100644 index 00000000..69262ce5 --- /dev/null +++ b/docs/customize.md @@ -0,0 +1,79 @@ +`@customize` allows you to register special hooks that control how inline-snapshot generates your snapshots. +You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. + +inline-snapshot calls each hook until it finds one that returns a custom object, which can be created with the `create_*` methods of the [`Builder`][inline_snapshot.Builder]. + +One use case might be that you have a dataclass with a special constructor function that can be used for certain instances of this dataclass, and you want inline-snapshot to use this constructor when possible. + + + +``` python +from dataclasses import dataclass + +from inline_snapshot import customize +from inline_snapshot import Builder +from inline_snapshot import snapshot + + +@dataclass +class Rect: + width: int + height: int + + @staticmethod + def make_quadrat(size): + return Rect(size, size) + + +@customize +def quadrat_handler(value, builder: Builder): + if isinstance(value, Rect) and value.width == value.height: + return builder.create_call(Rect.make_quadrat, [value.width]) + + +def test_quadrat(): + assert Rect.make_quadrat(5) == snapshot(Rect.make_quadrat(5)) # (1)! + assert Rect(1, 1) == snapshot(Rect.make_quadrat(1)) # (2)! + assert Rect(1, 2) == snapshot(Rect(width=1, height=2)) # (3)! +``` + +1. Your handler is used because you created a quadrat +2. Your handler is used because you created a rect that happens to have the same width and height +3. Your handler is not used because width and height are different + + +It can also be used to teach inline-snapshot to use specific dirty-equals expressions for specific values. + + + +``` python +from dataclasses import dataclass + +from inline_snapshot import customize +from inline_snapshot import Builder +from inline_snapshot import snapshot + +from dirty_equals import IsNow +from datetime import datetime + + +@customize +def quadrat_handler(value, builder: Builder): + if value == IsNow(): + return builder.create_call(IsNow) + + +def test_quadrat(): + assert datetime.now() == snapshot(IsNow()) +``` + + + + +::: inline_snapshot + options: + heading_level: 3 + members: [customize,Builder,Custom,CustomizeHandler] + show_root_heading: false + show_bases: false + show_source: false diff --git a/mkdocs.yml b/mkdocs.yml index 6759f2e9..cf13ae26 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,6 +63,7 @@ nav: - external_file(): external/external_file.md - outsource(): external/outsource.md - '@register_format()': external/register_format.md + - '@customize': customize.md - '@customize_repr': customize_repr.md - types: types.md - get_snapshot_value(): get_snapshot_value.md diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 69a573b2..5743d1cd 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -3,6 +3,10 @@ from ._code_repr import HasRepr from ._code_repr import customize_repr +from ._customize import Builder +from ._customize import Custom +from ._customize import CustomizeHandler +from ._customize import customize from ._exceptions import UsageError from ._external._external import external from ._external._external_file import external_file @@ -37,4 +41,8 @@ "declare_unmanaged", "get_snapshot_value", "__version__", + "customize", + "Custom", + "Builder", + "CustomizeHandler", ] diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 28d5cd62..890e0da7 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -81,7 +81,7 @@ def code_repr(obj): def mocked_code_repr(obj): from inline_snapshot._customize import Builder - return Builder().get_handler(obj).repr() + return Builder()._get_handler(obj).repr() def value_code_repr(obj): diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize.py index 4beac028..f8c928d0 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize.py @@ -5,32 +5,25 @@ from abc import abstractmethod from collections import Counter from collections import defaultdict +from dataclasses import MISSING +from dataclasses import dataclass +from dataclasses import field +from dataclasses import fields +from dataclasses import is_dataclass from pathlib import Path from pathlib import PurePath from types import BuiltinFunctionType from types import FunctionType from typing import Any from typing import Callable +from typing import TypeAlias from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._sentinels import undefined +from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import clone -custom_functions = [] - -from dataclasses import MISSING -from dataclasses import dataclass -from dataclasses import field -from dataclasses import fields -from dataclasses import is_dataclass - -from inline_snapshot._sentinels import undefined - - -def customize(f: Callable[[Any, Builder], Custom | None]): - custom_functions.append(f) - return f - class Custom(ABC): node_type: type[ast.AST] = ast.AST @@ -54,6 +47,89 @@ def eval(self): return self.map(lambda a: a) +CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] +""" +Type alias for customization handler functions. + +A customization handler is a function that takes a Python value and a Builder, +and returns either a Custom representation or None. + +The handler receives two parameters: + +- `value` (Any): The Python object to be converted to snapshot code +- `builder` (Builder): Helper object providing methods to create Custom representations + +The handler should return a Custom object if it processes the value type, or None otherwise. +""" + + +custom_functions = [] + + +def customize(f: CustomizeHandler) -> CustomizeHandler: + """ + Registers a function as a customization hook inside inline-snapshot. + + Customization hooks allow you to control how objects are represented in snapshot code. + When inline-snapshot generates code for a value, it calls each registered customization + function in reverse order of registration until one returns a Custom object. + + **Important**: Customization handlers should be registered in your `conftest.py` file to ensure + they are loaded before your tests run. + + Args: + f: A customization handler function. See [CustomizeHandler][inline_snapshot._customize.CustomizeHandler] + for the expected signature. + + Returns: + The input function unchanged (for use as a decorator) + + Example: + Basic usage with a custom class: + + + ``` python + from inline_snapshot import customize, snapshot + + + class MyClass: + def __init__(self, arg1, arg2, key=None): + self.arg1 = arg1 + self.arg2 = arg2 + self.key_attr = key + + + @customize + def my_custom_handler(value, builder): + if isinstance(value, MyClass): + # Generate code like: MyClass(arg1, arg2, key=value) + return builder.create_call( + MyClass, [value.arg1, value.arg2], {"key": value.key_attr} + ) + return None # Let other handlers process this value + + + def test_myclass(): + obj = MyClass(42, "hello", key="world") + assert obj == snapshot(MyClass(42, "hello", key="world")) + ``` + + Note: + - **Always register handlers in `conftest.py`** to ensure they're available for all tests + - Handlers are called in **reverse order** of registration (last registered is called first) + - If no handler returns a Custom object, a default representation is used + - Use builder methods (`create_call`, `create_list`, `create_dict`, etc.) to construct representations + - Always return `None` if your handler doesn't apply to the given value type + - The builder automatically handles recursive conversion of nested values + + See Also: + - [Builder][inline_snapshot._customize.Builder]: Available builder methods + - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations + """ + custom_functions.append(f) + return f + + @dataclass(frozen=True) class CustomDefault(Custom): value: Custom = field(compare=False) @@ -208,46 +284,46 @@ def __repr__(self): @customize def standard_handler(value, builder: Builder): if isinstance(value, list): - return builder.List(value) + return builder.create_list(value) if type(value) is tuple: - return builder.Tuple(value) + return builder.create_tuple(value) if isinstance(value, dict): - return builder.Dict(value) + return builder.create_dict(value) @customize def counter_handler(value, builder: Builder): if isinstance(value, Counter): - return builder.Call(value, Counter, [dict(value)]) + return builder.create_call(Counter, [dict(value)]) @customize def function_handler(value, builder: Builder): if isinstance(value, FunctionType): - return builder.Value(value, value.__qualname__) + return builder.create_value(value, value.__qualname__) @customize def builtin_function_handler(value, builder: Builder): if isinstance(value, BuiltinFunctionType): - return builder.Value(value, value.__name__) + return builder.create_value(value, value.__name__) @customize def type_handler(value, builder: Builder): if isinstance(value, type): - return builder.Value(value, value.__qualname__) + return builder.create_value(value, value.__qualname__) @customize def path_handler(value, builder: Builder): if isinstance(value, Path): - return builder.Call(value, Path, [value.as_posix()]) + return builder.create_call(Path, [value.as_posix()]) if isinstance(value, PurePath): - return builder.Call(value, PurePath, [value.as_posix()]) + return builder.create_call(PurePath, [value.as_posix()]) def sort_set_values(set_values): @@ -269,18 +345,20 @@ def sort_set_values(set_values): def set_handler(value, builder: Builder): if isinstance(value, set): if len(value) == 0: - return builder.Value(value, "set()") + return builder.create_value(value, "set()") else: - return builder.Value(value, "{" + ", ".join(sort_set_values(value)) + "}") + return builder.create_value( + value, "{" + ", ".join(sort_set_values(value)) + "}" + ) @customize def frozenset_handler(value, builder: Builder): if isinstance(value, frozenset): if len(value) == 0: - return builder.Value(value, "frozenset()") + return builder.create_value(value, "frozenset()") else: - return builder.Call(value, frozenset, [set(value)]) + return builder.create_call(frozenset, [set(value)]) @customize @@ -305,10 +383,10 @@ def dataclass_handler(value, builder: Builder): is_default = True if is_default: - field_value = builder.Default(field_value) + field_value = builder.create_default(field_value) kwargs[field.name] = field_value - return builder.Call(value, type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs, {}) try: @@ -345,11 +423,11 @@ def attrs_handler(value, builder: Builder): is_default = True if is_default: - field_value = builder.Default(field_value) + field_value = builder.create_default(field_value) kwargs[field.name] = field_value - return builder.Call(value, type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs, {}) try: @@ -399,11 +477,11 @@ def attrs_handler(value, builder: Builder): is_default = True if is_default: - field_value = builder.Default(field_value) + field_value = builder.create_default(field_value) kwargs[name] = field_value - return builder.Call(value, type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs, {}) @customize @@ -420,8 +498,7 @@ def namedtuple_handler(value, builder: Builder): # TODO handle with builder.Default - return builder.Call( - value, + return builder.create_call( type(value), [], { @@ -437,8 +514,8 @@ def namedtuple_handler(value, builder: Builder): @customize def defaultdict_handler(value, builder: Builder): if isinstance(value, defaultdict): - return builder.Call( - value, type(value), [value.default_factory, dict(value)], {}, {} + return builder.create_call( + type(value), [value.default_factory, dict(value)], {}, {} ) @@ -454,8 +531,20 @@ def undefined_handler(value, builder: Builder): return CustomUndefined() +@customize +def dirty_equals_handler(value, builder: Builder): + if is_dirty_equal(value) and builder._build_new_value: + if isinstance(value, type): + return builder.create_value(value, value.__name__) + else: + return builder.create_call(type(value)) + + +@dataclass class Builder: - def get_handler(self, v) -> Custom: + _build_new_value: bool = False + + def _get_handler(self, v) -> Custom: if isinstance(v, Custom): return v @@ -465,21 +554,39 @@ def get_handler(self, v) -> Custom: return r return CustomValue(v) - def List(self, value) -> CustomList: - custom = [self.get_handler(v) for v in value] + def create_list(self, value) -> Custom: + """ + Creates an intermediate node for a list-expression which can be used as a result for your customization function. + + `create_list([1,2,3])` becomes `[1,2,3]` in the code. + List elements are recursively converted into CustomNodes. + """ + custom = [self._get_handler(v) for v in value] return CustomList(value=custom) - def Tuple(self, value) -> CustomTuple: - custom = [self.get_handler(v) for v in value] + def create_tuple(self, value) -> Custom: + """ + Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. + + `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. + Tuple elements are recursively converted into CustomNodes. + """ + custom = [self._get_handler(v) for v in value] return CustomTuple(value=custom) - def Call( - self, value, function, posonly_args=[], kwargs={}, kwonly_args={} - ) -> CustomCall: - function = self.get_handler(function) - posonly_args = [self.get_handler(arg) for arg in posonly_args] - kwargs = {k: self.get_handler(arg) for k, arg in kwargs.items()} - kwonly_args = {k: self.get_handler(arg) for k, arg in kwonly_args.items()} + def create_call( + self, function, posonly_args=[], kwargs={}, kwonly_args={} + ) -> Custom: + """ + Creates an intermediate node for a function call expression which can be used as a result for your customization function. + + `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. + Function, arguments, and keyword arguments are recursively converted into CustomNodes. + """ + function = self._get_handler(function) + posonly_args = [self._get_handler(arg) for arg in posonly_args] + kwargs = {k: self._get_handler(arg) for k, arg in kwargs.items()} + kwonly_args = {k: self._get_handler(arg) for k, arg in kwonly_args.items()} return CustomCall( _function=function, @@ -488,12 +595,30 @@ def Call( _kwonly=kwonly_args, ) - def Default(self, value) -> CustomDefault: - return CustomDefault(value=self.get_handler(value)) + def create_default(self, value) -> Custom: + """ + Creates an intermediate node for a default value which can be used as a result for your customization function. - def Dict(self, value) -> CustomDict: - custom = {self.get_handler(k): self.get_handler(v) for k, v in value.items()} + Default values are not included in the generated code when they match the actual default. + The value is recursively converted into a CustomNode. + """ + return CustomDefault(value=self._get_handler(value)) + + def create_dict(self, value) -> Custom: + """ + Creates an intermediate node for a dict-expression which can be used as a result for your customization function. + + `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. + Dict keys and values are recursively converted into CustomNodes. + """ + custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def Value(self, value, repr) -> CustomValue: + def create_value(self, value, repr) -> CustomValue: + """ + Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. + + `create_value(my_obj, 'MyClass')` becomes `MyClass` in the code. + Use this when you want to control the exact string representation of a value. + """ return CustomValue(value, repr) diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 2fea9504..0f4d1406 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -23,10 +23,10 @@ def __contains__(self, item): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = CustomList([Builder().get_handler(item)]) + self._new_value = CustomList([Builder()._get_handler(item)]) else: if item not in self._new_value.eval(): - self._new_value.value.append(Builder().get_handler(item)) + self._new_value.value.append(Builder()._get_handler(item)) if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 24a196cc..7e2e4b00 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -24,7 +24,7 @@ def __getitem__(self, index): if isinstance(self._new_value, CustomUndefined): self._new_value = CustomDict({}) - index = Builder().get_handler(index) + index = Builder()._get_handler(index) if index not in self._new_value.value: if isinstance(self._old_value, CustomUndefined): diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 37b81078..6971e69e 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -17,7 +17,7 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - custom_other = Builder().get_handler(other) + custom_other = Builder(_build_new_value=True)._get_handler(other) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 @@ -35,8 +35,8 @@ def __eq__(self, other): break return self._return( - self._old_value.eval() == custom_other.eval(), - self._new_value.eval() == custom_other.eval(), + self._old_value.eval() == other, + self._new_value.eval() == other, ) def _new_code(self): diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index f1cece1f..5089001b 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -46,7 +46,7 @@ def _file(self): def _re_eval(self, value, context: AdapterContext): - self._old_value = reeval(self._old_value, Builder().get_handler(value)) + self._old_value = reeval(self._old_value, Builder()._get_handler(value)) return def _ignore_old(self): diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index ffa2d680..b59ef986 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -23,13 +23,13 @@ def _generic_cmp(self, other): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = Builder().get_handler(other) + self._new_value = Builder()._get_handler(other) if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True return self._return(self.cmp(self._old_value.eval(), other)) else: if not self.cmp(self._new_value.eval(), other): - self._new_value = Builder().get_handler(other) + self._new_value = Builder()._get_handler(other) return self._return(self.cmp(self._visible_value().eval(), other)) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 52efbf96..549a7156 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -97,7 +97,7 @@ def verify_call(value: Custom, node: ast.Call, eval) -> Custom: class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): - old_value = Builder().get_handler(old_value) + old_value = Builder()._get_handler(old_value) old_value = verify(old_value, ast_node, context.eval) assert isinstance(old_value, Custom) @@ -115,7 +115,7 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: assert isinstance(self._new_value, CustomUndefined) - new_value = Builder().get_handler(self._old_value.eval()) + new_value = Builder()._get_handler(self._old_value.eval()) adapter = NewAdapter(self._context) diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 88be9440..85f7df04 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -467,7 +467,7 @@ def __eq__(self,other): @customize def handle(value,builder): if isinstance(value,L): - return builder.Call(value,L,value.l) + return builder.create_call(L,value.l) def test_L1(): for _ in [1,2]: @@ -502,7 +502,7 @@ def __eq__(self,other): @customize def handle(value,builder): if isinstance(value,L): - return builder.Call(value,L,value.l) + return builder.create_call(L,value.l) def test_L1(): for _ in [1,2]: From 4bfb262087298fec80d99b28d701d8f7d27c6faa Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 16 Dec 2025 17:23:32 +0100 Subject: [PATCH 14/72] refactor: convert original ast to custom object --- src/inline_snapshot/_change.py | 2 +- src/inline_snapshot/_new_adapter.py | 95 +++++++------ .../_snapshot/generic_value.py | 15 +- .../_snapshot/undecided_value.py | 130 ++++++++---------- tests/adapter/test_dataclass.py | 4 +- 5 files changed, 126 insertions(+), 120 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index d510a11c..33e2d438 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -320,7 +320,7 @@ def arg_token_range(node): else len(parent.args) + len(parent.keywords) ) to_insert[position].append( - f"{change.arg_name} = {change.new_code}" + f"{change.arg_name}={change.new_code}" ) else: assert change.arg_pos is not None diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 6ec320b4..94e8d831 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -30,6 +30,55 @@ from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning +def warn_star_expression(node, context): + if isinstance(node, ast.Call): + for pos_arg in node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + # keyword arguments + for kw in node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + if isinstance(node, (ast.Tuple, ast.List)): + + for e in node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + if isinstance(node, ast.Dict): + + for key1, value in zip(node.keys, node.values): + if key1 is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + return False + + def reeval(old_value: Custom, value: Custom) -> Custom: if isinstance(old_value, CustomDefault): @@ -158,7 +207,7 @@ def compare_CustomValue( ) return old_value - if not old_value == new_value: + if not old_value.eval() == new_value.eval(): if isinstance(old_value, CustomUndefined): flag = "create" else: @@ -196,15 +245,8 @@ def compare_CustomSequence( ) assert isinstance(old_node, (ast.List, ast.Tuple)) - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value + if warn_star_expression(old_node, self.context): + return old_value else: pass # pragma: no cover @@ -266,15 +308,8 @@ def compare_CustomDict( if old_node is not None: - for key1, value in zip(old_node.keys, old_node.values): - if key1 is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value + if warn_star_expression(old_node, self.context): + return old_value for value2, node in zip(old_value.value.keys(), old_node.keys): assert node is not None @@ -365,26 +400,8 @@ def compare_CustomCall( if old_node is not None: # positional arguments - for pos_arg in old_node.args: - if isinstance(pos_arg, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=pos_arg.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - # keyword arguments - for kw in old_node.keywords: - if kw.arg is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=kw.value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value + if warn_star_expression(old_node, self.context): + return old_value call = new_value new_args = call.args diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 5089001b..b55706ab 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -44,10 +44,21 @@ def _return(self, result, new_result=True): def _file(self): return self._context.file + def value_to_custom(self, value): + if isinstance(value, Custom): + return value + + if self._ast_node is None: + return Builder()._get_handler(value) + else: + from inline_snapshot._snapshot.undecided_value import AstToCustom + + return AstToCustom(self._context).convert(value, self._ast_node) + def _re_eval(self, value, context: AdapterContext): + self._context = context - self._old_value = reeval(self._old_value, Builder()._get_handler(value)) - return + self._old_value = reeval(self._old_value, self.value_to_custom(value)) def _ignore_old(self): return ( diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 549a7156..89b38637 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -1,10 +1,10 @@ import ast +from typing import Any from typing import Iterator from inline_snapshot._customize import Builder from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import CustomDefault from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomTuple @@ -12,93 +12,71 @@ from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue from inline_snapshot._new_adapter import NewAdapter +from inline_snapshot._new_adapter import warn_star_expression +from inline_snapshot._unmanaged import is_unmanaged from .._adapter_context import AdapterContext from .._change import Change from .generic_value import GenericValue -def verify(value: Custom, node: ast.AST, eval) -> Custom: - """Verify that a Custom value matches its corresponding AST node structure.""" - if isinstance(value, CustomUnmanaged): - return value - if isinstance(value, CustomDefault): - return CustomDefault(value=verify(value.value, node, eval)) - - if isinstance(node, ast.List): - return verify_list(value, node, eval) - elif isinstance(node, ast.Tuple): - return verify_tuple(value, node, eval) - elif isinstance(node, ast.Dict): - return verify_dict(value, node, eval) - elif isinstance(node, ast.Call): - return verify_call(value, node, eval) - else: - # For other types, return the value as-is - return value - - -def verify_list(value: Custom, node: ast.List, eval) -> Custom: - """Verify a CustomList matches its List AST node.""" - assert isinstance(value, CustomList) - return CustomList([verify(v, n, eval) for v, n in zip(value.value, node.elts)]) - - -def verify_tuple(value: Custom, node: ast.Tuple, eval) -> Custom: - """Verify a CustomTuple matches its Tuple AST node.""" - assert isinstance(value, CustomTuple) - return CustomTuple([verify(v, n, eval) for v, n in zip(value.value, node.elts)]) - - -def verify_dict(value: Custom, node: ast.Dict, eval) -> Custom: - """Verify a CustomDict matches its Dict AST node.""" - assert isinstance(value, CustomDict) - if any(key is None for key in node.keys): - return value - - verified_items = {} - for (key, val), key_node, val_node in zip( - value.value.items(), node.keys, node.values - ): - verified_key = verify(key, key_node, eval) if key_node else key - verified_val = verify(val, val_node, eval) - verified_items[verified_key] = verified_val - return CustomDict(value=verified_items) - - -def verify_call(value: Custom, node: ast.Call, eval) -> Custom: - """Verify a CustomCall matches its Call AST node.""" - - if not isinstance(value, CustomCall) or eval(node.func) != value._function.eval(): - return CustomValue(eval(node), ast.unparse(node)) - - # Verify function - verified_function = verify(value._function, node.func, eval) - - # Verify positional arguments - verified_args = [] - for arg, arg_node in zip(value._args, node.args): - verified_args.append(verify(arg, arg_node, eval)) - - # Verify keyword arguments - verified_kwargs = {} - keyword_map = {kw.arg: kw.value for kw in node.keywords if kw.arg} - for key, val in value._kwargs.items(): - if key in keyword_map: - verified_kwargs[key] = verify(val, keyword_map[key], eval) +class AstToCustom: + + def __init__(self, context): + self.eval = context.eval + self.context = context + + def convert(self, value: Any, node: ast.expr): + if is_unmanaged(value): + return CustomUnmanaged(value) + + if warn_star_expression(node, self.context): + return self.convert_generic(value, node) + + t = type(node).__name__ + return getattr(self, "convert_" + t, self.convert_generic)(value, node) + + def eval_convert(self, node): + return self.convert(self.eval(node), node) + + def convert_generic(self, value: Any, node: ast.expr): + if value is ...: + return CustomUndefined() else: - verified_kwargs[key] = val + return CustomValue(value, ast.unparse(node)) + + def convert_Call(self, value: Any, node: ast.Call): + return CustomCall( + self.eval_convert(node.func), + [self.eval_convert(a) for a in node.args], + {kw.arg: self.eval_convert(kw.value) for kw in node.keywords if kw.arg}, + ) + + def convert_List(self, value: list, node: ast.List): + + return CustomList([self.convert(v, n) for v, n in zip(value, node.elts)]) + + def convert_Tuple(self, value: tuple, node: ast.Tuple): + return CustomTuple([self.convert(v, n) for v, n in zip(value, node.elts)]) - return CustomCall( - _function=verified_function, _args=verified_args, _kwargs=verified_kwargs - ) + def convert_Dict(self, value: dict, node: ast.Dict): + return CustomDict( + { + self.convert(k, k_node): self.convert(v, v_node) + for (k, v), k_node, v_node in zip(value.items(), node.keys, node.values) + if k_node is not None + } + ) class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): - old_value = Builder()._get_handler(old_value) - old_value = verify(old_value, ast_node, context.eval) + if not isinstance(old_value, Custom): + if ast_node is not None: + old_value = AstToCustom(context).convert(old_value, ast_node) + else: + old_value = Builder()._get_handler(old_value) assert isinstance(old_value, Custom) self._old_value = old_value @@ -120,7 +98,7 @@ def _get_changes(self) -> Iterator[Change]: adapter = NewAdapter(self._context) for change in adapter.compare(self._old_value, self._ast_node, new_value): - assert change.flag == "update" + assert change.flag == "update", change yield change # def handle(node, obj): diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 85f7df04..3ebbccf1 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -374,7 +374,7 @@ def test_something(): assert A(a=3) == snapshot(A(**{"a":5})),"not equal" """ ).run_inline( - ["--inline-snapshot=fix"], + ["--inline-snapshot=report"], raises=snapshot( """\ AssertionError: @@ -414,7 +414,7 @@ class A: c:int=0 def test_something(): - assert A(a=3,b=3,c=3) == snapshot(A(a = 3, b=3, c = 3)),"not equal" + assert A(a=3,b=3,c=3) == snapshot(A(a=3, b=3, c=3)),"not equal" """ } ), From 8b34536bb846139678f697ed4b92eb44754ca85c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 17 Dec 2025 10:06:40 +0100 Subject: [PATCH 15/72] fix: update works now with customize --- .../{_customize.py => _customize/__init__.py} | 62 +++++-------------- src/inline_snapshot/_customize/_custom.py | 51 +++++++++++++++ src/inline_snapshot/_global_state.py | 4 ++ src/inline_snapshot/_new_adapter.py | 27 +------- src/inline_snapshot/fix_pytest_diff.py | 13 ---- tests/test_customize.py | 47 ++++++++++++++ 6 files changed, 122 insertions(+), 82 deletions(-) rename src/inline_snapshot/{_customize.py => _customize/__init__.py} (93%) create mode 100644 src/inline_snapshot/_customize/_custom.py create mode 100644 tests/test_customize.py diff --git a/src/inline_snapshot/_customize.py b/src/inline_snapshot/_customize/__init__.py similarity index 93% rename from src/inline_snapshot/_customize.py rename to src/inline_snapshot/_customize/__init__.py index f8c928d0..29911c12 100644 --- a/src/inline_snapshot/_customize.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -10,6 +10,7 @@ from dataclasses import field from dataclasses import fields from dataclasses import is_dataclass +from dataclasses import replace from pathlib import Path from pathlib import PurePath from types import BuiltinFunctionType @@ -19,51 +20,14 @@ from typing import TypeAlias from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._customize._custom import CustomizeHandler from inline_snapshot._sentinels import undefined from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import clone - -class Custom(ABC): - node_type: type[ast.AST] = ast.AST - - def __hash__(self): - return hash(self.eval()) - - def __eq__(self, other): - assert isinstance(other, Custom) - return self.eval() == other.eval() - - @abstractmethod - def map(self, f): - raise NotImplementedError() - - @abstractmethod - def repr(self): - raise NotImplementedError() - - def eval(self): - return self.map(lambda a: a) - - -CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] -""" -Type alias for customization handler functions. - -A customization handler is a function that takes a Python value and a Builder, -and returns either a Custom representation or None. - -The handler receives two parameters: - -- `value` (Any): The Python object to be converted to snapshot code -- `builder` (Builder): Helper object providing methods to create Custom representations - -The handler should return a Custom object if it processes the value type, or None otherwise. -""" - - -custom_functions = [] +from ._custom import Custom +from ._custom import CustomizeHandler def customize(f: CustomizeHandler) -> CustomizeHandler: @@ -126,7 +90,9 @@ def test_myclass(): - [Builder][inline_snapshot._customize.Builder]: Available builder methods - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations """ - custom_functions.append(f) + from inline_snapshot._global_state import state + + state().custom_functions.append(f) return f @@ -147,7 +113,7 @@ class CustomUnmanaged(Custom): value: Any def repr(self): - return "" + return "'unmanaged'" # pragma: no cover def map(self, f): return f(self.value) @@ -548,11 +514,17 @@ def _get_handler(self, v) -> Custom: if isinstance(v, Custom): return v - for f in reversed(custom_functions): + from inline_snapshot._global_state import state + + for f in reversed(state().custom_functions): r = f(v, self) if isinstance(r, Custom): - return r - return CustomValue(v) + break + else: + r = CustomValue(v) + + r.__dict__["original_value"] = v + return r def create_list(self, value) -> Custom: """ diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py new file mode 100644 index 00000000..5ebe55a2 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import ast +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import TypeAlias + +if TYPE_CHECKING: + from inline_snapshot._customize import Builder + + +class Custom(ABC): + node_type: type[ast.AST] = ast.AST + original_value: Any + + def __hash__(self): + return hash(self.eval()) + + def __eq__(self, other): + assert isinstance(other, Custom) + return self.eval() == other.eval() + + @abstractmethod + def map(self, f): + raise NotImplementedError() + + @abstractmethod + def repr(self): + raise NotImplementedError() + + def eval(self): + return self.map(lambda a: a) + + +CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] +""" +Type alias for customization handler functions. + +A customization handler is a function that takes a Python value and a Builder, +and returns either a Custom representation or None. + +The handler receives two parameters: + +- `value` (Any): The Python object to be converted to snapshot code +- `builder` (Builder): Helper object providing methods to create Custom representations + +The handler should return a Custom object if it processes the value type, or None otherwise. +""" diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 57fdc589..9e535c73 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -15,6 +15,7 @@ from inline_snapshot._config import Config if TYPE_CHECKING: + from inline_snapshot._customize._custom import CustomizeHandler from inline_snapshot._external._format._protocol import Format from inline_snapshot._external._storage._protocol import StorageProtocol from inline_snapshot._types import SnapshotRefBase @@ -56,6 +57,8 @@ def new_tmp_path(self, suffix: str) -> Path: disable_reason: Literal["xdist", "ci", "implementation", None] = None + custom_functions: list[CustomizeHandler] = field(default_factory=list) + _latest_global_states: list[State] = [] @@ -72,6 +75,7 @@ def enter_snapshot_context(): latest = _current _latest_global_states.append(_current) _current = State() + _current.custom_functions = list(latest.custom_functions) _current.all_formats = dict(latest.all_formats) _current.config = deepcopy(latest.config) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 94e8d831..ab2e63ef 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -25,7 +25,7 @@ from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue from inline_snapshot._exceptions import UsageError -from inline_snapshot._utils import value_to_token +from inline_snapshot._utils import map_strings from inline_snapshot.syntax_warnings import InlineSnapshotInfo from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning @@ -80,21 +80,10 @@ def warn_star_expression(node, context): def reeval(old_value: Custom, value: Custom) -> Custom: - - if isinstance(old_value, CustomDefault): - return reeval(old_value.value, value) - - if isinstance(value, CustomDefault): - return CustomDefault(value=reeval(old_value, value.value)) - - if type(old_value) is not type(value): - return CustomUnmanaged(value.eval()) - function_name = f"reeval_{type(old_value).__name__}" result = globals()[function_name](old_value, value) assert isinstance(result, Custom) - # assert result == value,(result,value) return result @@ -191,7 +180,7 @@ def compare_CustomValue( if old_node is None: new_token = [] else: - new_token = value_to_token(new_value.eval()) + new_token = map_strings(new_value.repr()) if ( isinstance(old_node, ast.JoinedStr) @@ -207,7 +196,7 @@ def compare_CustomValue( ) return old_value - if not old_value.eval() == new_value.eval(): + if not old_value.eval() == new_value.original_value: if isinstance(old_value, CustomUndefined): flag = "create" else: @@ -245,8 +234,6 @@ def compare_CustomSequence( ) assert isinstance(old_node, (ast.List, ast.Tuple)) - if warn_star_expression(old_node, self.context): - return old_value else: pass # pragma: no cover @@ -308,9 +295,6 @@ def compare_CustomDict( if old_node is not None: - if warn_star_expression(old_node, self.context): - return old_value - for value2, node in zip(old_value.value.keys(), old_node.keys): assert node is not None try: @@ -398,11 +382,6 @@ def compare_CustomCall( self, old_value: CustomCall, old_node: ast.Call, new_value: CustomCall ) -> Generator[Change, None, Custom]: - if old_node is not None: - # positional arguments - if warn_star_expression(old_node, self.context): - return old_value - call = new_value new_args = call.args new_kwargs = call.kwargs diff --git a/src/inline_snapshot/fix_pytest_diff.py b/src/inline_snapshot/fix_pytest_diff.py index e33ff366..48f23635 100644 --- a/src/inline_snapshot/fix_pytest_diff.py +++ b/src/inline_snapshot/fix_pytest_diff.py @@ -22,19 +22,6 @@ def _pprint_snapshot( PrettyPrinter._dispatch[GenericValue.__repr__] = _pprint_snapshot - # def _pprint_unmanaged( - # self, - # object: Any, - # stream: IO[str], - # indent: int, - # allowance: int, - # context: Set[int], - # level: int, - # ) -> None: - # self._format(object.value, stream, indent, allowance, context, level) - - # PrettyPrinter._dispatch[Unmanaged.__repr__] = _pprint_unmanaged - def _pprint_is( self, object: Any, diff --git a/tests/test_customize.py b/tests/test_customize.py new file mode 100644 index 00000000..bcf37512 --- /dev/null +++ b/tests/test_customize.py @@ -0,0 +1,47 @@ +import pytest + +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +@pytest.mark.parametrize( + "original,flag", [("'a'", "update"), ("'b'", "fix"), ("", "create")] +) +def test_custom_dirty_equal(original, flag): + + Example( + f"""\ +from inline_snapshot import customize +from inline_snapshot import Builder +from inline_snapshot import snapshot +from dirty_equals import IsStr + +@customize +def re_handler(value, builder: Builder): + if value == IsStr(regex="[a-z]"): + return builder.create_call(IsStr, [], {{"regex": "[a-z]"}}) + +def test_a(): + assert snapshot({original}) == "a" +""" + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import customize +from inline_snapshot import Builder +from inline_snapshot import snapshot +from dirty_equals import IsStr + +@customize +def re_handler(value, builder: Builder): + if value == IsStr(regex="[a-z]"): + return builder.create_call(IsStr, [], {"regex": "[a-z]"}) + +def test_a(): + assert snapshot(IsStr(regex="[a-z]")) == "a" +""" + } + ), + ) From c89f8d1c104c98828708cede88c8a5faa51139f6 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 21 Dec 2025 14:22:23 +0100 Subject: [PATCH 16/72] feat: implemented import support for created snapshots --- src/inline_snapshot/_change.py | 15 +++++ src/inline_snapshot/_customize/__init__.py | 64 ++++++++++++++++++- .../_external/_find_external.py | 42 +++++++++++- src/inline_snapshot/_inline_snapshot.py | 10 +++ src/inline_snapshot/_new_adapter.py | 55 ++++++++++++---- .../_snapshot/generic_value.py | 3 + src/inline_snapshot/testing/_example.py | 25 +++++--- tests/adapter/test_dataclass.py | 2 +- tests/conftest.py | 12 +++- tests/test_customize.py | 40 ++++++++++++ 10 files changed, 240 insertions(+), 28 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 33e2d438..253478a1 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -14,6 +14,7 @@ from executing.executing import EnhancedAST from inline_snapshot._external._external_location import Location +from inline_snapshot._external._find_external import ensure_import from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder @@ -115,6 +116,11 @@ def apply_external_changes(self): pass +@dataclass() +class RequiredImports(Change): + imports: dict[str, list[str]] + + @dataclass() class Delete(Change): node: ast.AST | None @@ -251,6 +257,9 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): ) sources: dict[EnhancedAST, SourceFile] = {} + # file -> module -> names + imports_by_file = defaultdict(lambda: defaultdict(set)) + for change in all_changes: if isinstance(change, Delete): node = cast(EnhancedAST, change.node).parent @@ -263,9 +272,15 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): node = cast(EnhancedAST, change.node) by_parent[node].append(change) sources[node] = change.file + elif isinstance(change, RequiredImports): + for module, names in change.imports.items(): + imports_by_file[change.filename][module] |= set(names) else: change.apply(recorder) + for filename, imports in imports_by_file.items(): + ensure_import(filename, imports, recorder) + for parent, changes in by_parent.items(): source = sources[parent] diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 29911c12..39f267fb 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import importlib from abc import ABC from abc import abstractmethod from collections import Counter @@ -107,6 +108,9 @@ def repr(self): def map(self, f): return self.value.map(f) + def _needed_imports(self): + yield from self.value._needed_imports() + @dataclass() class CustomUnmanaged(Custom): @@ -118,6 +122,9 @@ def repr(self): def map(self, f): return f(self.value) + def _needed_imports(self): + yield from () + class CustomUndefined(Custom): def __init__(self): @@ -129,6 +136,9 @@ def repr(self) -> str: def map(self, f): return f(undefined) + def _needed_imports(self): + yield from () + def unwrap_default(value): if isinstance(value, CustomDefault): @@ -178,6 +188,17 @@ def map(self, f): **{k: f(v.map(f)) for k, v in self.kwargs.items()}, ) + def _needed_imports(self): + yield from self._function._needed_imports() + for v in self._args: + yield from v._needed_imports() + + for v in self._kwargs.values(): + yield from v._needed_imports() + + for v in self._kwonly.values(): + yield from v._needed_imports() + class CustomSequenceTypes: trailing_comma: bool @@ -196,6 +217,10 @@ def repr(self) -> str: trailing_comma = self.trailing_comma and len(self.value) == 1 return f"{self.braces[0]}{', '.join(v.repr() for v in self.value)}{', ' if trailing_comma else ''}{self.braces[1]}" + def _needed_imports(self): + for v in self.value: + yield from v._needed_imports() + class CustomList(CustomSequence): node_type = ast.List @@ -224,6 +249,11 @@ def repr(self) -> str: f"{{{ ', '.join(f'{k.repr()}: {v.repr()}' for k,v in self.value.items())}}}" ) + def _needed_imports(self): + for k, v in self.value.items(): + yield from k._needed_imports() + yield from v._needed_imports() + class CustomValue(Custom): def __init__(self, value, repr_str=None): @@ -236,6 +266,9 @@ def __init__(self, value, repr_str=None): self.repr_str = repr_str self.value = value + self._imports = defaultdict(list) + + super().__init__() def map(self, f): return f(self.value) @@ -246,6 +279,27 @@ def repr(self) -> str: def __repr__(self): return f"CustomValue({self.repr_str})" + def _needed_imports(self): + yield from self._imports.items() + + def with_import(self, module, name, simplify=True): + value = getattr(importlib.import_module(module), name) + if simplify: + parts = module.split(".") + while len(parts) >= 2: + if ( + getattr(importlib.import_module(".".join(parts[:-1])), name, None) + == value + ): + parts.pop() + else: + break + module = ".".join(parts) + + self._imports[module].append(name) + + return self + @customize def standard_handler(value, builder: Builder): @@ -268,7 +322,9 @@ def counter_handler(value, builder: Builder): @customize def function_handler(value, builder: Builder): if isinstance(value, FunctionType): - return builder.create_value(value, value.__qualname__) + qualname = value.__qualname__ + name = qualname.split(".")[0] + return builder.create_value(value, qualname).with_import(value.__module__, name) @customize @@ -280,7 +336,9 @@ def builtin_function_handler(value, builder: Builder): @customize def type_handler(value, builder: Builder): if isinstance(value, type): - return builder.create_value(value, value.__qualname__) + qualname = value.__qualname__ + name = qualname.split(".")[0] + return builder.create_value(value, qualname).with_import(value.__module__, name) @customize @@ -586,7 +644,7 @@ def create_dict(self, value) -> Custom: custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def create_value(self, value, repr) -> CustomValue: + def create_value(self, value, repr: Optional[str] = None) -> CustomValue: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index 5246e72c..e00c2eb1 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -1,7 +1,9 @@ import ast +import os from dataclasses import replace from pathlib import Path from typing import List +from typing import Optional from typing import Union from executing import Source @@ -64,6 +66,38 @@ def used_externals_in( return usages +def module_name_of(filename: str | os.PathLike) -> Optional[str]: + path = Path(filename).resolve() + + if path.suffix != ".py": + return None + + parts = [] + + if path.name != "__init__.py": + parts.append(path.stem) + + current = path.parent + + while current != current.root: + if not (current / "__init__.py").exists(): + break + + parts.append(current.name) + + next_parent = current.parent + if next_parent == current: + break + current = next_parent + + parts.reverse() + + if not parts: + return None + + return ".".join(parts) + + def ensure_import(filename, imports, recorder: ChangeRecorder): source = Source.for_filename(filename) @@ -74,8 +108,14 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): to_add = [] + my_module = module_name_of(filename) + for module, names in imports.items(): - for name in names: + if module == my_module: + continue + if module == "builtins": + continue + for name in sorted(names): if not contains_import(tree, module, name): to_add.append((module, name)) diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 6f23dd2c..82137525 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,5 +1,6 @@ import ast import inspect +from collections import defaultdict from typing import Any from typing import Iterator from typing import TypeVar @@ -15,6 +16,7 @@ from ._change import CallArg from ._change import Change +from ._change import RequiredImports from ._global_state import state from ._sentinels import undefined from ._snapshot.undecided_value import UndecidedValue @@ -149,6 +151,14 @@ def _changes(self) -> Iterator[Change]: new_value=self._value._new_value, ) + imports = defaultdict(set) + for module, names in self._value._needed_imports(): + imports[module] |= set(names) + + yield RequiredImports( + flag="create", file=self._value._file, imports=imports + ) + else: yield from self._value._get_changes() diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index ab2e63ef..e28a3fb8 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -14,6 +14,7 @@ from inline_snapshot._change import DictInsert from inline_snapshot._change import ListInsert from inline_snapshot._change import Replace +from inline_snapshot._change import RequiredImports from inline_snapshot._compare_context import compare_context from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomCall @@ -222,6 +223,15 @@ def compare_CustomValue( new_value=new_value, ) + def needed_imports(value: Custom): + imports = defaultdict(set) + for module, names in value._needed_imports(): + imports[module] |= set(names) + return imports + + if imports := needed_imports(new_value): + yield RequiredImports(flag, self.context.file._source, imports) + return new_value def compare_CustomSequence( @@ -390,6 +400,23 @@ def compare_CustomCall( result_args = [] + flag = "update" if old_value.eval() == new_value.original_value else "fix" + + if flag == "update": + + def intercept(stream): + while True: + try: + change = next(stream) + if change.flag == "fix": + change.flag = "update" + yield change + except StopIteration as stop: + return stop.value + + else: + intercept = lambda a: a + old_node_args: Sequence[ast.expr | None] if old_node: old_node_args = old_node.args @@ -398,7 +425,9 @@ def compare_CustomCall( for i, (new_value_element, node) in enumerate(zip(new_args, old_node_args)): old_value_element = old_value.argument(i) - result = yield from self.compare(old_value_element, node, new_value_element) + result = yield from intercept( + self.compare(old_value_element, node, new_value_element) + ) result_args.append(result) old_args_len = len(old_node.args if old_node else old_value.args) @@ -407,7 +436,7 @@ def compare_CustomCall( if old_args_len > len(new_args): for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: yield Delete( - "fix", + flag, self.context.file._source, node, old_value.argument(arg_pos), @@ -416,7 +445,7 @@ def compare_CustomCall( if old_args_len < len(new_args): for insert_pos, value in list(enumerate(new_args))[old_args_len:]: yield CallArg( - flag="fix", + flag=flag, file=self.context.file._source, node=old_node, arg_pos=insert_pos, @@ -441,7 +470,7 @@ def compare_CustomCall( ( "update" if old_value.argument(kw_arg) == new_value.argument(kw_arg) - else "fix" + else flag ), self.context.file._source, kw_value, @@ -462,15 +491,15 @@ def compare_CustomCall( # check values with same keys old_value_element = old_value.argument(key) - result_kwargs[key] = yield from self.compare( - old_value_element, node, new_value_element + result_kwargs[key] = yield from intercept( + self.compare(old_value_element, node, new_value_element) ) if to_insert: for key, value in to_insert: yield CallArg( - flag="fix", + flag=flag, file=self.context.file._source, node=old_node, arg_pos=insert_pos, @@ -487,7 +516,7 @@ def compare_CustomCall( for key, value in to_insert: yield CallArg( - flag="fix", + flag=flag, file=self.context.file._source, node=old_node, arg_pos=insert_pos, @@ -498,10 +527,12 @@ def compare_CustomCall( return CustomCall( ( - yield from self.compare( - old_value._function, - old_node.func if old_node else None, - new_value._function, + yield from intercept( + self.compare( + old_value._function, + old_node.func if old_node else None, + new_value._function, + ) ) ), result_args, diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index b55706ab..00a91184 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -68,6 +68,9 @@ def _ignore_old(self): or isinstance(self._old_value, CustomUndefined) ) + def _needed_imports(self): + yield from self._new_value._needed_imports() + def _visible_value(self): if self._ignore_old(): return self._new_value diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 53c90d89..135f0648 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -1,12 +1,12 @@ from __future__ import annotations +import importlib.util import os import platform import random import re import subprocess as sp import sys -import tokenize import traceback import uuid from argparse import ArgumentParser @@ -14,7 +14,6 @@ from io import StringIO from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any from typing import Callable from unittest.mock import patch @@ -319,6 +318,7 @@ def report_error(message): raise StopTesting(message) snapshot_flags = set() + old_modules = sys.modules try: enter_snapshot_context() session.load_config( @@ -334,19 +334,23 @@ def report_error(message): tests_found = False for filename in tmp_path.rglob("test_*.py"): - globals: dict[str, Any] = {} - print("run> pytest-inline", filename) - with tokenize.open(filename) as f: - code = f.read() - exec( - compile(code, filename, "exec"), - globals, + print("run> pytest-inline", filename, *args) + + # Load module using importlib + spec = importlib.util.spec_from_file_location( + filename.stem, filename ) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + sys.modules[filename.stem] = module + spec.loader.exec_module(module) + else: + raise UsageError(f"Could not load module from {filename}") # run all test_* functions tests = [ v - for k, v in globals.items() + for k, v in module.__dict__.items() if (k.startswith("test_") or k == "test") and callable(v) ] tests_found |= len(tests) != 0 @@ -378,6 +382,7 @@ def fail(message): except StopTesting as e: assert stderr == f"ERROR: {e}\n" finally: + sys.modules = old_modules leave_snapshot_context() if reported_categories is not None: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 3ebbccf1..b54297c0 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -153,7 +153,7 @@ class A: def test_something(): for _ in [1,2]: - assert A(a=1) == snapshot(A(1,2)) + assert A(a=1) == snapshot(A(a=1)) """ } ), diff --git a/tests/conftest.py b/tests/conftest.py index c1524e2f..cd50b575 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import importlib.util import os import platform import re @@ -113,7 +114,16 @@ def run(self, *flags_arg: Category): error = False try: - exec(compile(filename.read_text("utf-8"), filename, "exec"), {}) + # Load module using importlib + spec = importlib.util.spec_from_file_location( + filename.stem, filename + ) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + sys.modules[filename.stem] = module + spec.loader.exec_module(module) + else: + raise RuntimeError(f"Could not load module from {filename}") except AssertionError: traceback.print_exc() error = True diff --git a/tests/test_customize.py b/tests/test_customize.py index bcf37512..e6dfb8db 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -45,3 +45,43 @@ def test_a(): } ), ) + + +@pytest.mark.parametrize( + "original,flag", + [("{'1': 1, '2': 2}", "update"), ("5", "fix"), ("", "create")], +) +def test_create_imports(original, flag): + + Example( + { + "tests/test_something.py": f"""\ +from inline_snapshot import snapshot + +def counter(): + from collections import Counter + return Counter("122") + +def test(): + assert counter() == snapshot({original}) +""" + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +from collections import Counter + +def counter(): + from collections import Counter + return Counter("122") + +def test(): + assert counter() == snapshot(Counter({"1": 1, "2": 2})) +""" + } + ), + ) From b031d574cdea8cfa3ce5ce198f4e2da0bc64d45c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 21 Dec 2025 15:10:45 +0100 Subject: [PATCH 17/72] refactor: removed special import generation for HasRepr and external --- src/inline_snapshot/_change.py | 4 +-- src/inline_snapshot/_code_repr.py | 17 ------------ src/inline_snapshot/_customize/__init__.py | 23 ++++++++++++++-- src/inline_snapshot/_customize/_custom.py | 4 +++ src/inline_snapshot/_inline_snapshot.py | 2 +- src/inline_snapshot/_new_adapter.py | 2 +- src/inline_snapshot/_snapshot_session.py | 24 ---------------- tests/test_code_repr.py | 32 ++++++++++++++-------- 8 files changed, 50 insertions(+), 58 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 253478a1..41e97c55 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -118,7 +118,7 @@ def apply_external_changes(self): @dataclass() class RequiredImports(Change): - imports: dict[str, list[str]] + imports: dict[str, set[str]] @dataclass() @@ -258,7 +258,7 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): sources: dict[EnhancedAST, SourceFile] = {} # file -> module -> names - imports_by_file = defaultdict(lambda: defaultdict(set)) + imports_by_file: dict[str, dict[str, set]] = defaultdict(lambda: defaultdict(set)) for change in all_changes: if isinstance(change, Delete): diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 890e0da7..03c43599 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,4 +1,3 @@ -import ast from enum import Enum from enum import Flag from functools import singledispatch @@ -37,17 +36,6 @@ def __eq__(self, other): return other_repr == self._str_repr or other_repr == repr(self) -def used_hasrepr(tree): - return [ - n - for n in ast.walk(tree) - if isinstance(n, ast.Call) - and isinstance(n.func, ast.Name) - and n.func.id == "HasRepr" - and len(n.args) == 2 - ] - - @singledispatch def code_repr_dispatch(value): return real_repr(value) @@ -94,11 +82,6 @@ def value_code_repr(obj): result = code_repr_dispatch(obj) - try: - ast.parse(result) - except SyntaxError: - return real_repr(HasRepr(type(obj), result)) - return result diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 39f267fb..dbc28774 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -18,8 +18,10 @@ from types import FunctionType from typing import Any from typing import Callable +from typing import Optional from typing import TypeAlias +from inline_snapshot._code_repr import HasRepr from inline_snapshot._code_repr import value_code_repr from inline_snapshot._customize._custom import CustomizeHandler from inline_snapshot._sentinels import undefined @@ -259,14 +261,20 @@ class CustomValue(Custom): def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) value = clone(value) + self._imports = defaultdict(list) if repr_str is None: self.repr_str = value_code_repr(value) + + try: + ast.parse(self.repr_str) + except SyntaxError: + self.repr_str = HasRepr(type(value), self.repr_str).__repr__() + self.with_import("inline_snapshot", "HasRepr") else: self.repr_str = repr_str self.value = value - self._imports = defaultdict(list) super().__init__() @@ -561,9 +569,20 @@ def dirty_equals_handler(value, builder: Builder): if isinstance(value, type): return builder.create_value(value, value.__name__) else: + # TODO: args return builder.create_call(type(value)) +@customize +def outsourced_handler(value, builder: Builder): + from inline_snapshot._external._outsource import Outsourced + + if isinstance(value, Outsourced): + return builder.create_value(value, repr(value)).with_import( + "inline_snapshot", "external" + ) + + @dataclass class Builder: _build_new_value: bool = False @@ -644,7 +663,7 @@ def create_dict(self, value) -> Custom: custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def create_value(self, value, repr: Optional[str] = None) -> CustomValue: + def create_value(self, value, repr: str | None = None) -> CustomValue: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 5ebe55a2..39fa880e 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -34,6 +34,10 @@ def repr(self): def eval(self): return self.map(lambda a: a) + @abstractmethod + def _needed_imports(self): + raise NotImplementedError() + CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] """ diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 82137525..9aa1966c 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -151,7 +151,7 @@ def _changes(self) -> Iterator[Change]: new_value=self._value._new_value, ) - imports = defaultdict(set) + imports: dict[str, set[str]] = defaultdict(set) for module, names in self._value._needed_imports(): imports[module] |= set(names) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index e28a3fb8..e7daeed9 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -224,7 +224,7 @@ def compare_CustomValue( ) def needed_imports(value: Custom): - imports = defaultdict(set) + imports: dict[str, set] = defaultdict(set) for module, names in value._needed_imports(): imports[module] |= set(names) return imports diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index a103c718..09456764 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -1,4 +1,3 @@ -import ast import os import sys import tokenize @@ -21,8 +20,6 @@ from ._change import ChangeBase from ._change import ExternalRemove from ._change import apply_all -from ._code_repr import used_hasrepr -from ._external._find_external import ensure_import from ._external._find_external import used_externals_in from ._flags import Flags from ._global_state import state @@ -454,25 +451,4 @@ def console(): for change in used_changes: change.apply_external_changes() - for test_file in cr.files(): - tree = ast.parse(test_file.new_code()) - used_externals = used_externals_in( - test_file.filename, tree, check_import=False - ) - - required_imports = [] - - if used_externals: - required_imports.append("external") - - if used_hasrepr(tree): - required_imports.append("HasRepr") - - if required_imports: - ensure_import( - test_file.filename, - {"inline_snapshot": required_imports}, - cr, - ) - cr.fix_all() diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 20f1c80f..169ab440 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -407,9 +407,11 @@ def __repr__(self): def test_invalid_repr(check_update): - assert ( - check_update( - """\ + + Example( + """\ +from inline_snapshot import snapshot + class Thing: def __repr__(self): return "+++" @@ -419,12 +421,18 @@ def __eq__(self,other): return NotImplemented return True -assert Thing() == snapshot() -""", - flags="create", - ) - == snapshot( - """\ +def test_a(): + assert Thing() == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot import HasRepr + class Thing: def __repr__(self): return "+++" @@ -434,9 +442,11 @@ def __eq__(self,other): return NotImplemented return True -assert Thing() == snapshot(HasRepr(Thing, "+++")) +def test_a(): + assert Thing() == snapshot(HasRepr(Thing, "+++")) """ - ) + } + ), ) From cc2f1c4a71f03a2ac0a5052038885416eb7de5e9 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 23 Dec 2025 14:26:38 +0100 Subject: [PATCH 18/72] feat: support for local_vars and global_vars in customize --- src/inline_snapshot/_code_repr.py | 2 +- src/inline_snapshot/_customize/__init__.py | 64 ++++++++++++++++--- .../_snapshot/collection_value.py | 5 +- src/inline_snapshot/_snapshot/dict_value.py | 3 +- src/inline_snapshot/_snapshot/eq_value.py | 3 +- .../_snapshot/generic_value.py | 5 +- .../_snapshot/min_max_value.py | 5 +- .../_snapshot/undecided_value.py | 11 ++-- 8 files changed, 73 insertions(+), 25 deletions(-) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 03c43599..837cda27 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -69,7 +69,7 @@ def code_repr(obj): def mocked_code_repr(obj): from inline_snapshot._customize import Builder - return Builder()._get_handler(obj).repr() + return Builder(_snapshot_context=None)._get_handler(obj).repr() def value_code_repr(obj): diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index dbc28774..cbc42e7f 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -21,9 +21,13 @@ from typing import Optional from typing import TypeAlias +from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._code_repr import HasRepr from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._compare_context import compare_context +from inline_snapshot._compare_context import compare_only from inline_snapshot._customize._custom import CustomizeHandler +from inline_snapshot._partial_call import partial_call from inline_snapshot._sentinels import undefined from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged @@ -583,25 +587,67 @@ def outsourced_handler(value, builder: Builder): ) +@dataclass +class ContextValue: + name: str + value: Any + + +@customize +def context_value_handler(value, builder: Builder): + if isinstance(value, ContextValue): + return builder.create_value(value.value, value.name) + + @dataclass class Builder: + _snapshot_context: AdapterContext _build_new_value: bool = False def _get_handler(self, v) -> Custom: - if isinstance(v, Custom): - return v from inline_snapshot._global_state import state - for f in reversed(state().custom_functions): - r = f(v, self) - if isinstance(r, Custom): - break + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + local_vars = [ + ContextValue(var_name, var_value) + for var_name, var_value in frame.locals.items() + if "@" not in var_name + ] + global_vars = [ + ContextValue(var_name, var_value) + for var_name, var_value in frame.globals.items() + if "@" not in var_name + ] else: - r = CustomValue(v) + local_vars = {} + global_vars = {} + + result = v + + while not isinstance(result, Custom): + for f in reversed(state().custom_functions): + with compare_context(): + r = partial_call( + f, + { + "value": result, + "builder": self, + "local_vars": local_vars, + "global_vars": global_vars, + }, + ) + if r is not None: + result = r + break + else: + result = CustomValue(result) - r.__dict__["original_value"] = v - return r + result.__dict__["original_value"] = v + return result def create_list(self, value) -> Custom: """ diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 0f4d1406..9d168600 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -1,7 +1,6 @@ import ast from typing import Iterator -from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomUndefined @@ -23,10 +22,10 @@ def __contains__(self, item): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = CustomList([Builder()._get_handler(item)]) + self._new_value = CustomList([self.get_builder()._get_handler(item)]) else: if item not in self._new_value.eval(): - self._new_value.value.append(Builder()._get_handler(item)) + self._new_value.value.append(self.get_builder()._get_handler(item)) if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 7e2e4b00..b3d60264 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -1,7 +1,6 @@ import ast from typing import Iterator -from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomDict from inline_snapshot._customize import CustomUndefined @@ -24,7 +23,7 @@ def __getitem__(self, index): if isinstance(self._new_value, CustomUndefined): self._new_value = CustomDict({}) - index = Builder()._get_handler(index) + index = self.get_builder()._get_handler(index) if index not in self._new_value.value: if isinstance(self._old_value, CustomUndefined): diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 6971e69e..b168f641 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -1,7 +1,6 @@ from typing import Iterator from typing import List -from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomUndefined from inline_snapshot._new_adapter import NewAdapter from inline_snapshot._utils import map_strings @@ -17,7 +16,7 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - custom_other = Builder(_build_new_value=True)._get_handler(other) + custom_other = self.get_builder(_build_new_value=True)._get_handler(other) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 00a91184..31d0d3ca 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -25,6 +25,9 @@ class GenericValue(SnapshotBase): _ast_node: ast.Expr _context: AdapterContext + def get_builder(self, **args): + return Builder(_snapshot_context=self._context, **args) + def _return(self, result, new_result=True): if not result: @@ -49,7 +52,7 @@ def value_to_custom(self, value): return value if self._ast_node is None: - return Builder()._get_handler(value) + return self.get_builder()._get_handler(value) else: from inline_snapshot._snapshot.undecided_value import AstToCustom diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index b59ef986..21ace18f 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,6 +1,5 @@ from typing import Iterator -from inline_snapshot._customize import Builder from inline_snapshot._customize import CustomUndefined from .._change import Change @@ -23,13 +22,13 @@ def _generic_cmp(self, other): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = Builder()._get_handler(other) + self._new_value = self.get_builder()._get_handler(other) if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True return self._return(self.cmp(self._old_value.eval(), other)) else: if not self.cmp(self._new_value.eval(), other): - self._new_value = Builder()._get_handler(other) + self._new_value = self.get_builder()._get_handler(other) return self._return(self.cmp(self._visible_value().eval(), other)) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 89b38637..8bbe4027 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -2,7 +2,7 @@ from typing import Any from typing import Iterator -from inline_snapshot._customize import Builder +from inline_snapshot._compare_context import compare_only from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomCall from inline_snapshot._customize import CustomDict @@ -71,18 +71,18 @@ def convert_Dict(self, value: dict, node: ast.Dict): class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): + self._context = context if not isinstance(old_value, Custom): if ast_node is not None: old_value = AstToCustom(context).convert(old_value, ast_node) else: - old_value = Builder()._get_handler(old_value) + old_value = self.get_builder()._get_handler(old_value) assert isinstance(old_value, Custom) self._old_value = old_value self._new_value = CustomUndefined() self._ast_node = ast_node - self._context = context def _change(self, cls): self.__class__ = cls @@ -93,7 +93,7 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: assert isinstance(self._new_value, CustomUndefined) - new_value = Builder()._get_handler(self._old_value.eval()) + new_value = self.get_builder()._get_handler(self._old_value.eval()) adapter = NewAdapter(self._context) @@ -128,6 +128,9 @@ def _get_changes(self) -> Iterator[Change]: # functions which determine the type def __eq__(self, other): + if compare_only(): + return False + from .._snapshot.eq_value import EqValue self._change(EqValue) From 8be6a497e07e17c89f4cbed72c8be25d92a349f1 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 23 Dec 2025 19:06:59 +0100 Subject: [PATCH 19/72] fix: support dirty-equals expressions with arguments --- src/inline_snapshot/_customize/__init__.py | 48 ++++++++++++++++++---- src/inline_snapshot/_global_state.py | 9 +++- src/inline_snapshot/_partial_call.py | 26 ++++++++++++ src/inline_snapshot/_unmanaged.py | 5 +-- 4 files changed, 74 insertions(+), 14 deletions(-) create mode 100644 src/inline_snapshot/_partial_call.py diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index cbc42e7f..b770b7d9 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -12,6 +12,7 @@ from dataclasses import fields from dataclasses import is_dataclass from dataclasses import replace +from functools import partial from pathlib import Path from pathlib import PurePath from types import BuiltinFunctionType @@ -20,6 +21,7 @@ from typing import Callable from typing import Optional from typing import TypeAlias +from typing import overload from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._code_repr import HasRepr @@ -37,7 +39,19 @@ from ._custom import CustomizeHandler -def customize(f: CustomizeHandler) -> CustomizeHandler: +@overload +def customize( + f: None = None, *, priority: int = 0 +) -> Callable[[CustomizeHandler], CustomizeHandler]: ... + + +@overload +def customize(f: CustomizeHandler, *, priority: int = 0) -> CustomizeHandler: ... + + +def customize( + f: CustomizeHandler | None = None, *, priority: int = 0 +) -> CustomizeHandler | Callable[[CustomizeHandler], CustomizeHandler]: """ Registers a function as a customization hook inside inline-snapshot. @@ -97,9 +111,13 @@ def test_myclass(): - [Builder][inline_snapshot._customize.Builder]: Available builder methods - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations """ + + if f is None: + return partial(customize, priority=priority) # type: ignore[return-value] + from inline_snapshot._global_state import state - state().custom_functions.append(f) + state().custom_functions[priority].append(f) return f @@ -567,14 +585,20 @@ def undefined_handler(value, builder: Builder): return CustomUndefined() -@customize +@customize(priority=1000) def dirty_equals_handler(value, builder: Builder): + if is_dirty_equal(value) and builder._build_new_value: if isinstance(value, type): - return builder.create_value(value, value.__name__) + return builder.create_value(value, value.__name__).with_import( + "dirty_equals", value.__name__ + ) else: - # TODO: args - return builder.create_call(type(value)) + from dirty_equals._utils import Omit + + args = [a for a in value._repr_args if a is not Omit] + kwargs = {k: a for k, a in value._repr_kwargs.items() if a is not Omit} + return builder.create_call(type(value), args, kwargs) @customize @@ -623,13 +647,19 @@ def _get_handler(self, v) -> Custom: if "@" not in var_name ] else: - local_vars = {} - global_vars = {} + local_vars = [] + global_vars = [] result = v + custom_functions = [ + f + for _, function_list in sorted(state().custom_functions.items()) + for f in function_list + ] + while not isinstance(result, Custom): - for f in reversed(state().custom_functions): + for f in reversed(custom_functions): with compare_context(): r = partial_call( f, diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 9e535c73..2bac6202 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +from collections import defaultdict from copy import deepcopy from dataclasses import dataclass from dataclasses import field @@ -57,7 +58,9 @@ def new_tmp_path(self, suffix: str) -> Path: disable_reason: Literal["xdist", "ci", "implementation", None] = None - custom_functions: list[CustomizeHandler] = field(default_factory=list) + custom_functions: dict[int, list[CustomizeHandler]] = field( + default_factory=lambda: defaultdict(list) + ) _latest_global_states: list[State] = [] @@ -75,7 +78,9 @@ def enter_snapshot_context(): latest = _current _latest_global_states.append(_current) _current = State() - _current.custom_functions = list(latest.custom_functions) + _current.custom_functions = defaultdict( + list, {k: list(v) for k, v in latest.custom_functions.items()} + ) _current.all_formats = dict(latest.all_formats) _current.config = deepcopy(latest.config) diff --git a/src/inline_snapshot/_partial_call.py b/src/inline_snapshot/_partial_call.py new file mode 100644 index 00000000..dbdc623b --- /dev/null +++ b/src/inline_snapshot/_partial_call.py @@ -0,0 +1,26 @@ +import inspect + +from inline_snapshot._exceptions import UsageError + + +def check_args(func, allowed): + sign = inspect.signature(func) + for p in sign.parameters.values(): + if p.default is not inspect.Parameter.empty: + raise UsageError(f"`{p.name}` has a default value which is not supported") + + if p.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: + raise UsageError( + f"`{p.name}` is not a positional or keyword parameter, which is not supported" + ) + + if p.name not in allowed: + raise UsageError( + f"`{p.name}` is an unknown parameter. allowed are {allowed}" + ) + + +def partial_call(func, args): + sign = inspect.signature(func) + used = [p.name for p in sign.parameters.values()] + return func(**{n: args[n] for n in used}) diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py index 9d15e626..63065770 100644 --- a/src/inline_snapshot/_unmanaged.py +++ b/src/inline_snapshot/_unmanaged.py @@ -8,9 +8,8 @@ def is_dirty_equal(value): else: def is_dirty_equal(value): - return isinstance(value, dirty_equals.DirtyEquals) or ( - isinstance(value, type) and issubclass(value, dirty_equals.DirtyEquals) - ) + t = value if isinstance(value, type) else type(value) + return any(x is dirty_equals.DirtyEquals for x in t.__mro__) def update_allowed(value): From adfb4ecf573c4f49f63f4e1dab9bc686471389d1 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 28 Dec 2025 20:32:17 +0100 Subject: [PATCH 20/72] refactor: reimplemented outsource with @customize --- src/inline_snapshot/_code_repr.py | 19 ++- src/inline_snapshot/_customize/__init__.py | 107 +++++++++++--- src/inline_snapshot/_customize/_custom.py | 6 +- src/inline_snapshot/_external/_external.py | 3 + .../_external/_external_base.py | 10 +- src/inline_snapshot/_external/_outsource.py | 66 ++++++--- src/inline_snapshot/_generator_utils.py | 17 +++ src/inline_snapshot/_get_snapshot_value.py | 3 + src/inline_snapshot/_inline_snapshot.py | 6 +- src/inline_snapshot/_new_adapter.py | 30 ++-- .../_snapshot/collection_value.py | 9 +- src/inline_snapshot/_snapshot/dict_value.py | 25 ++-- src/inline_snapshot/_snapshot/eq_value.py | 7 +- .../_snapshot/generic_value.py | 7 +- .../_snapshot/min_max_value.py | 9 +- .../_snapshot/undecided_value.py | 4 +- src/inline_snapshot/_source_file.py | 1 + tests/conftest.py | 14 +- tests/external/test_external.py | 139 +++++++++++++----- tests/test_pytest_plugin.py | 3 + 20 files changed, 354 insertions(+), 131 deletions(-) create mode 100644 src/inline_snapshot/_generator_utils.py diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 837cda27..ed547d8d 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,8 +1,11 @@ from enum import Enum from enum import Flag +from functools import partial from functools import singledispatch from unittest import mock +from inline_snapshot._generator_utils import only_value + real_repr = repr @@ -60,19 +63,25 @@ def _(obj: MyCustomClass): code_repr_dispatch.register(f) -def code_repr(obj): +def code_repr(obj, context=None): + + new_repr = partial(mocked_code_repr, context=context) - with mock.patch("builtins.repr", mocked_code_repr): - return mocked_code_repr(obj) + with mock.patch("builtins.repr", new_repr): + return new_repr(obj) -def mocked_code_repr(obj): +def mocked_code_repr(obj, context): from inline_snapshot._customize import Builder - return Builder(_snapshot_context=None)._get_handler(obj).repr() + return only_value( + Builder(_snapshot_context=context)._get_handler(obj).repr(context) + ) def value_code_repr(obj): + # TODO: check the called functions + if not type(obj) == type(obj): # pragma: no cover # this was caused by https://github.com/samuelcolvin/dirty-equals/issues/104 # dispatch will not work in cases like this diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index b770b7d9..5eb7af2e 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -19,16 +19,24 @@ from types import FunctionType from typing import Any from typing import Callable +from typing import Generator from typing import Optional from typing import TypeAlias from typing import overload from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import Change +from inline_snapshot._change import ChangeBase +from inline_snapshot._change import ExternalChange from inline_snapshot._code_repr import HasRepr from inline_snapshot._code_repr import value_code_repr from inline_snapshot._compare_context import compare_context from inline_snapshot._compare_context import compare_only from inline_snapshot._customize._custom import CustomizeHandler +from inline_snapshot._external._external_location import ExternalLocation +from inline_snapshot._external._format._protocol import get_format_handler +from inline_snapshot._external._format._protocol import get_format_handler_from_suffix +from inline_snapshot._global_state import state from inline_snapshot._partial_call import partial_call from inline_snapshot._sentinels import undefined from inline_snapshot._unmanaged import is_dirty_equal @@ -125,7 +133,8 @@ def test_myclass(): class CustomDefault(Custom): value: Custom = field(compare=False) - def repr(self): + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () # this should never be called because default values are never converted into code assert False @@ -140,7 +149,8 @@ def _needed_imports(self): class CustomUnmanaged(Custom): value: Any - def repr(self): + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () return "'unmanaged'" # pragma: no cover def map(self, f): @@ -154,7 +164,8 @@ class CustomUndefined(Custom): def __init__(self): self.value = undefined - def repr(self) -> str: + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () return "..." def map(self, f): @@ -178,15 +189,18 @@ class CustomCall(Custom): _kwargs: dict[str, Custom] = field(compare=False) _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) - def repr(self) -> str: + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: args = [] - args += [a.repr() for a in self.args] - args += [ - f"{k}={v.repr()}" - for k, v in self.kwargs.items() - if not isinstance(v, CustomDefault) - ] - return f"{self._function.repr()}({', '.join(args)})" + for a in self.args: + v = yield from a.repr(context) + args.append(v) + + for k, v in self.kwargs.items(): + if not isinstance(v, CustomDefault): + value = yield from v.repr(context) + args.append(f"{k}={value}") + + return f"{yield from self._function.repr(context)}({', '.join(args)})" @property def args(self): @@ -237,9 +251,14 @@ class CustomSequence(Custom, CustomSequenceTypes): def map(self, f): return f(self.value_type([x.map(f) for x in self.value])) - def repr(self) -> str: + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for v in self.value: + value = yield from v.repr(context) + values.append(value) + trailing_comma = self.trailing_comma and len(self.value) == 1 - return f"{self.braces[0]}{', '.join(v.repr() for v in self.value)}{', ' if trailing_comma else ''}{self.braces[1]}" + return f"{self.braces[0]}{', '.join(values)}{', ' if trailing_comma else ''}{self.braces[1]}" def _needed_imports(self): for v in self.value: @@ -268,10 +287,14 @@ class CustomDict(Custom): def map(self, f): return f({k.map(f): v.map(f) for k, v in self.value.items()}) - def repr(self) -> str: - return ( - f"{{{ ', '.join(f'{k.repr()}: {v.repr()}' for k,v in self.value.items())}}}" - ) + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for k, v in self.value.items(): + key = yield from k.repr(context) + value = yield from v.repr(context) + values.append(f"{key}: {value}") + + return f"{{{ ', '.join(values)}}}" def _needed_imports(self): for k, v in self.value.items(): @@ -279,6 +302,49 @@ def _needed_imports(self): yield from v._needed_imports() +@dataclass(frozen=True) +class CustomExternal(Custom): + value: Any + format: str | None = None + storage: str | None = None + + def map(self, f): + return f(self.value) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + storage_name = self.storage or state().config.default_storage + + format = get_format_handler(self.value, self.format or "") + + location = ExternalLocation( + storage=storage_name, + stem="", + suffix=format.suffix, + filename=Path(context.file.filename), + qualname=context.qualname, + ) + + tmp_file = state().new_tmp_path(location.suffix) + + storage = state().all_storages[storage_name] + + format.encode(self.value, tmp_file) + location = storage.new_location(location, tmp_file) + + yield ExternalChange( + "create", + tmp_file, + ExternalLocation.from_name("", context=context), + location, + format, + ) + + return f"external({location.to_str()!r})" + + def _needed_imports(self): + return [("inline_snapshot", ["external"])] + + class CustomValue(Custom): def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) @@ -303,7 +369,8 @@ def __init__(self, value, repr_str=None): def map(self, f): return f(self.value) - def repr(self) -> str: + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () return self.repr_str def __repr__(self): @@ -679,6 +746,10 @@ def _get_handler(self, v) -> Custom: result.__dict__["original_value"] = v return result + def create_external(self, value: Any, format: str | None, storage: str | None): + + return CustomExternal(value, format=format, storage=storage) + def create_list(self, value) -> Custom: """ Creates an intermediate node for a list-expression which can be used as a result for your customization function. diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 39fa880e..12306910 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -6,8 +6,12 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Generator from typing import TypeAlias +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + if TYPE_CHECKING: from inline_snapshot._customize import Builder @@ -28,7 +32,7 @@ def map(self, f): raise NotImplementedError() @abstractmethod - def repr(self): + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def eval(self): diff --git a/src/inline_snapshot/_external/_external.py b/src/inline_snapshot/_external/_external.py index b6800db9..553c5493 100644 --- a/src/inline_snapshot/_external/_external.py +++ b/src/inline_snapshot/_external/_external.py @@ -32,6 +32,9 @@ def is_inside_testdir(path: Path) -> bool: @declare_unmanaged class External(ExternalBase): + _original_location: ExternalLocation + _location: ExternalLocation + def __init__(self, name: str, expr, context: AdapterContext): """External objects are used to move some data outside the source code. You should not instantiate this class directly, but by using `external()` instead. diff --git a/src/inline_snapshot/_external/_external_base.py b/src/inline_snapshot/_external/_external_base.py index 76ae3de7..404e4675 100644 --- a/src/inline_snapshot/_external/_external_base.py +++ b/src/inline_snapshot/_external/_external_base.py @@ -4,6 +4,7 @@ from inline_snapshot._change import ExternalChange from inline_snapshot._exceptions import UsageError +from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._external._outsource import Outsourced from inline_snapshot._global_state import state @@ -46,15 +47,18 @@ def __eq__(self, other): __tracebackhide__ = True if isinstance(other, Outsourced): - self._location.suffix = other._location.suffix + + format = get_format_handler(other.data, other.suffix or "") + + self._location.suffix = other.suffix or format.suffix other = other.data - if isinstance(other, ExternalBase): + elif isinstance(other, ExternalBase): raise UsageError( f"you can not compare {external_type}(...) with {external_type}(...)" ) - if isinstance(other, GenericValue): + elif isinstance(other, GenericValue): raise UsageError( f"you can not compare {external_type}(...) with snapshot(...)" ) diff --git a/src/inline_snapshot/_external/_outsource.py b/src/inline_snapshot/_external/_outsource.py index 9a4c23ef..36d4224d 100644 --- a/src/inline_snapshot/_external/_outsource.py +++ b/src/inline_snapshot/_external/_outsource.py @@ -1,35 +1,40 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any -from inline_snapshot._external._external_location import ExternalLocation +from inline_snapshot._customize import Builder +from inline_snapshot._customize import customize from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._global_state import state - -from .._snapshot.generic_value import GenericValue +from inline_snapshot._snapshot.generic_value import GenericValue +@dataclass class Outsourced: - def __init__(self, data: Any, suffix: str | None): - self.data = data + data: Any + suffix: str | None + storage: str | None + # def __init__(self, data: Any, suffix: str | None): + # self.data = data - self._format = get_format_handler(data, suffix or "") - if suffix is None: - suffix = self._format.suffix + # self._format = get_format_handler(data, suffix or "") + # if suffix is None: + # suffix = self._format.suffix - self._location = ExternalLocation("hash", "", suffix, None, None) + # self._location = ExternalLocation("hash", "", suffix, None, None) - tmp_path = state().new_tmp_path(suffix) + # tmp_path = state().new_tmp_path(suffix) - self._format.encode(data, tmp_path) + # self._format.encode(data, tmp_path) - storage = state().all_storages["hash"] + # storage = state().all_storages["hash"] - self._location = storage.new_location( - self._location, tmp_path # type:ignore - ) + # self._location = storage.new_location( + # self._location, tmp_path # type:ignore + # ) - storage.store(self._location, tmp_path) # type: ignore + # storage.store(self._location, tmp_path) # type: ignore def __eq__(self, other): if isinstance(other, GenericValue): @@ -38,20 +43,37 @@ def __eq__(self, other): if isinstance(other, Outsourced): return self.data == other.data + from inline_snapshot._external._external_base import ExternalBase + + if isinstance(other, ExternalBase): + return NotImplemented + + return self.data == other + return NotImplemented - def __repr__(self) -> str: - return f'external("{self._location.to_str()}")' + # def __repr__(self) -> str: + # return f'external("{self._location.to_str()}")' + + # def _load_value(self) -> Any: + # return self.data - def _load_value(self) -> Any: - return self.data +@customize +def outsource_handler(value, builder: Builder): + if isinstance(value, Outsourced): + return builder.create_external( + value.data, format=value.suffix, storage=value.storage + ) -def outsource(data: Any, suffix: str | None = None) -> Any: + +def outsource(data: Any, suffix: str | None = None, storage: str | None = None) -> Any: if suffix and suffix[0] != ".": raise ValueError("suffix has to start with a '.' like '.png'") if not state().active: return data - return Outsourced(data, suffix) + format = get_format_handler(data, suffix or "") + + return Outsourced(data, suffix, storage) diff --git a/src/inline_snapshot/_generator_utils.py b/src/inline_snapshot/_generator_utils.py new file mode 100644 index 00000000..bc094d88 --- /dev/null +++ b/src/inline_snapshot/_generator_utils.py @@ -0,0 +1,17 @@ +from collections import namedtuple + +IterResult = namedtuple("IterResult", "value list") + + +def split_gen(gen): + it = iter(gen) + l = [] + while True: + try: + l.append(next(it)) + except StopIteration as stop: + return IterResult(stop.value, l) + + +def only_value(gen): + return split_gen(gen).value diff --git a/src/inline_snapshot/_get_snapshot_value.py b/src/inline_snapshot/_get_snapshot_value.py index abe99f89..b34e6a1f 100644 --- a/src/inline_snapshot/_get_snapshot_value.py +++ b/src/inline_snapshot/_get_snapshot_value.py @@ -15,6 +15,9 @@ def unwrap(value): if isinstance(value, GenericValue): return value._visible_value().map(lambda v: unwrap(v)[0]), True + if isinstance(value, Outsourced): + return (value.data, True) + if isinstance(value, (External, Outsourced, ExternalFile)): try: return unwrap(value._load_value())[0], True diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 9aa1966c..dc829d91 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -15,7 +15,7 @@ from inline_snapshot._types import SnapshotRefBase from ._change import CallArg -from ._change import Change +from ._change import ChangeBase from ._change import RequiredImports from ._global_state import state from ._sentinels import undefined @@ -128,7 +128,7 @@ def result(self): def create_raw(obj, context: AdapterContext): return obj - def _changes(self) -> Iterator[Change]: + def _changes(self) -> Iterator[ChangeBase]: if ( isinstance(self._value._old_value, CustomUndefined) @@ -139,7 +139,7 @@ def _changes(self) -> Iterator[Change]: if isinstance(self._value._new_value, CustomUndefined): return - new_code = self._value._new_code() + new_code = yield from self._value._new_code() yield CallArg( flag="create", diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index e7daeed9..4817510c 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -9,7 +9,7 @@ from inline_snapshot._align import add_x from inline_snapshot._align import align from inline_snapshot._change import CallArg -from inline_snapshot._change import Change +from inline_snapshot._change import ChangeBase from inline_snapshot._change import Delete from inline_snapshot._change import DictInsert from inline_snapshot._change import ListInsert @@ -26,6 +26,7 @@ from inline_snapshot._customize import CustomUnmanaged from inline_snapshot._customize import CustomValue from inline_snapshot._exceptions import UsageError +from inline_snapshot._generator_utils import only_value from inline_snapshot._utils import map_strings from inline_snapshot.syntax_warnings import InlineSnapshotInfo from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning @@ -141,7 +142,7 @@ def __init__(self, context): def compare( self, old_value: Custom, old_node, new_value: Custom - ) -> Generator[Change, None, Custom]: + ) -> Generator[ChangeBase, None, Custom]: if isinstance(old_value, CustomUnmanaged): return old_value @@ -172,7 +173,7 @@ def compare( def compare_CustomValue( self, old_value: Custom, old_node: ast.expr, new_value: Custom - ) -> Generator[Change, None, Custom]: + ) -> Generator[ChangeBase, None, Custom]: assert isinstance(old_value, Custom) assert isinstance(new_value, Custom) @@ -181,7 +182,8 @@ def compare_CustomValue( if old_node is None: new_token = [] else: - new_token = map_strings(new_value.repr()) + new_code = yield from new_value.repr(self.context) + new_token = map_strings(new_code) if ( isinstance(old_node, ast.JoinedStr) @@ -189,8 +191,10 @@ def compare_CustomValue( and isinstance(new_value.value, str) ): if not old_value.eval() == new_value.eval(): + + value = only_value(new_value.repr(self.context)) warnings.warn_explicit( - f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value.repr()}", + f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {value}", filename=self.context.file._source.filename, lineno=old_node.lineno, category=InlineSnapshotInfo, @@ -236,7 +240,7 @@ def needed_imports(value: Custom): def compare_CustomSequence( self, old_value: CustomSequence, old_node: ast.AST, new_value: CustomSequence - ) -> Generator[Change, None, CustomSequence]: + ) -> Generator[ChangeBase, None, CustomSequence]: if old_node is not None: assert isinstance( @@ -299,7 +303,7 @@ def compare_CustomSequence( def compare_CustomDict( self, old_value: CustomDict, old_node: ast.Dict, new_value: CustomDict - ) -> Generator[Change, None, Custom]: + ) -> Generator[ChangeBase, None, Custom]: assert isinstance(old_value, CustomDict) assert isinstance(new_value, CustomDict) @@ -390,7 +394,7 @@ def compare_CustomDict( def compare_CustomCall( self, old_value: CustomCall, old_node: ast.Call, new_value: CustomCall - ) -> Generator[Change, None, Custom]: + ) -> Generator[ChangeBase, None, Custom]: call = new_value new_args = call.args @@ -444,13 +448,14 @@ def intercept(stream): if old_args_len < len(new_args): for insert_pos, value in list(enumerate(new_args))[old_args_len:]: + new_code = yield from value.repr(self.context) yield CallArg( flag=flag, file=self.context.file._source, node=old_node, arg_pos=insert_pos, arg_name=None, - new_code=value.repr(), + new_code=new_code, new_value=value, ) @@ -497,14 +502,14 @@ def intercept(stream): if to_insert: for key, value in to_insert: - + new_code = yield from value.repr(self.context) yield CallArg( flag=flag, file=self.context.file._source, node=old_node, arg_pos=insert_pos, arg_name=key, - new_code=value.repr(), + new_code=new_code, new_value=value, ) to_insert = [] @@ -514,6 +519,7 @@ def intercept(stream): if to_insert: for key, value in to_insert: + new_code = yield from value.repr(self.context) yield CallArg( flag=flag, @@ -521,7 +527,7 @@ def intercept(stream): node=old_node, arg_pos=insert_pos, arg_name=key, - new_code=value.repr(), + new_code=new_code, new_value=value, ) diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 9d168600..6ee7a03d 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -1,14 +1,17 @@ import ast +from typing import Generator from typing import Iterator from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomUndefined from .._change import Change +from .._change import ChangeBase from .._change import Delete from .._change import ListInsert from .._change import Replace from .._global_state import state +from .._utils import map_strings from .._utils import value_to_token from .generic_value import GenericValue from .generic_value import ignore_old_value @@ -32,9 +35,9 @@ def __contains__(self, item): else: return self._return(item in self._old_value.eval()) - def _new_code(self): - # TODO repr() ... - return self._file._value_to_code(self._new_value.eval()) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value.repr(self._context) + return self._file._token_to_code(map_strings(code)) def _get_changes(self) -> Iterator[Change]: assert isinstance(self._old_value, CustomList), self._old_value diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index b3d60264..f4bbe6f3 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -1,4 +1,5 @@ import ast +from typing import Generator from typing import Iterator from inline_snapshot._customize import CustomDict @@ -6,6 +7,7 @@ from .._adapter_context import AdapterContext from .._change import Change +from .._change import ChangeBase from .._change import Delete from .._change import DictInsert from .._global_state import state @@ -56,18 +58,14 @@ def _re_eval(self, value, context: AdapterContext): if key in self._old_value.value: s._re_eval(self._old_value.value[key], context) # type:ignore - def _new_code(self): - return ( - "{" - + ", ".join( - [ - f"{self._file._value_to_code(k)}: {v._new_code()}" # type:ignore - for k, v in self._new_value.value.items() - if not isinstance(v, UndecidedValue) - ] - ) - + "}" - ) + def _new_code(self) -> Generator[ChangeBase, None, str]: + values = [] + for k, v in self._new_value.value.items(): + if not isinstance(v, UndecidedValue): + new_code = yield from v._new_code() # type:ignore + values.append(f"{self._file._value_to_code(k)}: {new_code}") + + return "{" + ", ".join(values) + "}" def _get_changes(self) -> Iterator[Change]: @@ -93,7 +91,8 @@ def _get_changes(self) -> Iterator[Change]: new_value_element, UndecidedValue ): # add new values - to_insert.append((key, new_value_element._new_code())) # type:ignore + new_code = yield from new_value_element._new_code() # type:ignore + to_insert.append((key, new_code)) if to_insert: new_code = [(self._file._value_to_code(k.eval()), v) for k, v in to_insert] diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index b168f641..127dd8b4 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -1,3 +1,4 @@ +from typing import Generator from typing import Iterator from typing import List @@ -6,6 +7,7 @@ from inline_snapshot._utils import map_strings from .._change import Change +from .._change import ChangeBase from .._compare_context import compare_only from .._global_state import state from .generic_value import GenericValue @@ -38,8 +40,9 @@ def __eq__(self, other): self._new_value.eval() == other, ) - def _new_code(self): - return self._file._token_to_code(map_strings(self._new_value.repr())) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value.repr(self._context) + return self._file._token_to_code(map_strings(code)) def _get_changes(self) -> Iterator[Change]: return iter(getattr(self, "_changes", [])) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 31d0d3ca..961dee47 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -1,4 +1,5 @@ import ast +from typing import Generator from typing import Iterator from inline_snapshot._adapter_context import AdapterContext @@ -7,7 +8,7 @@ from inline_snapshot._customize import CustomUndefined from inline_snapshot._new_adapter import reeval -from .._change import Change +from .._change import ChangeBase from .._global_state import state from .._types import SnapshotBase from .._unmanaged import declare_unmanaged @@ -80,10 +81,10 @@ def _visible_value(self): else: return self._old_value - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: raise NotImplementedError() - def _new_code(self): + def _new_code(self) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def __repr__(self): diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 21ace18f..18486339 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,10 +1,13 @@ +from typing import Generator from typing import Iterator from inline_snapshot._customize import CustomUndefined from .._change import Change +from .._change import ChangeBase from .._change import Replace from .._global_state import state +from .._utils import map_strings from .._utils import value_to_token from .generic_value import GenericValue from .generic_value import ignore_old_value @@ -32,9 +35,9 @@ def _generic_cmp(self, other): return self._return(self.cmp(self._visible_value().eval(), other)) - def _new_code(self): - # TODO repr() ... - return self._file._value_to_code(self._new_value.eval()) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value.repr(self._context) + return self._file._token_to_code(map_strings(code)) def _get_changes(self) -> Iterator[Change]: # TODO repr() ... diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 8bbe4027..14635a03 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -16,7 +16,7 @@ from inline_snapshot._unmanaged import is_unmanaged from .._adapter_context import AdapterContext -from .._change import Change +from .._change import ChangeBase from .generic_value import GenericValue @@ -90,7 +90,7 @@ def _change(self, cls): def _new_code(self): assert False - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: assert isinstance(self._new_value, CustomUndefined) new_value = self.get_builder()._get_handler(self._old_value.eval()) diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index d5c3dc15..dd343525 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -45,6 +45,7 @@ def _token_to_code(self, tokens): return self._format(tokenize.untokenize(tokens)).strip() def _value_to_code(self, value): + # TODO remove this return self._token_to_code(value_to_token(value)) def _token_of_node(self, node): diff --git a/tests/conftest.py b/tests/conftest.py index cd50b575..593e0d7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ from inline_snapshot._global_state import snapshot_env from inline_snapshot._rewrite_code import ChangeRecorder from inline_snapshot._types import Category +from inline_snapshot.testing._example import deterministic_uuid pytest_plugins = "pytester" @@ -103,7 +104,7 @@ def run(self, *flags_arg: Category): print("input:") print(textwrap.indent(source, " |", lambda line: True).rstrip()) - with snapshot_env() as state: + with snapshot_env() as state, deterministic_uuid(): recorder = ChangeRecorder() state.update_flags = flags state.all_storages["hash"] = inline_snapshot._external.HashStorage( @@ -325,6 +326,17 @@ def _(value:FakeDate): def set_time(freezer): freezer.move_to(datetime.datetime(2024, 3, 14, 0, 0, 0, 0)) yield + +import uuid +import random + +rd = random.Random(0) + +def f(): + return uuid.UUID(int=rd.getrandbits(128), version=4) + +uuid.uuid4 = f + """ ) diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 17283721..5d71d5d0 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -19,7 +19,7 @@ def test_basic(check_update): assert check_update( "assert outsource('text') == snapshot()", flags="create" ) == snapshot( - "assert outsource('text') == snapshot(external(\"hash:982d9e3eb996*.txt\"))" + "assert outsource('text') == snapshot(external(\"uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt\"))" ) @@ -48,9 +48,6 @@ def test_a(): ["--inline-snapshot=create"], changed_files=snapshot( { - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.bin": "test", - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.log": "test", - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.txt": "test", "tests/__inline_snapshot__/test_something/test_a/e3e70682-c209-4cac-a29f-6fbed82c07cd.txt": "test", "tests/__inline_snapshot__/test_something/test_a/eb1167b3-67a9-4378-bc65-c1e582e2e662.bin": "test", "tests/__inline_snapshot__/test_something/test_a/f728b4fa-4248-4e3a-8a5d-2f346baa9455.log": "test", @@ -74,20 +71,60 @@ def test_a(): ) -def test_diskstorage(): - with snapshot_env(): +def test_hash_collision(): + e = ( + Example( + { + "pyproject.toml": """\ +[tool.inline-snapshot] +default-storage="hash" +""", + "tests/test_a.py": """ +from inline_snapshot import outsource,snapshot,external - assert outsource("test4") == snapshot(external("hash:a4e624d686e0*.txt")) - assert outsource("test5") == snapshot(external("hash:a140c0c1eda2*.txt")) - assert outsource("test6") == snapshot(external("hash:ed0cb90bdfa4*.txt")) +def test_a(): + assert outsource("test4") == snapshot() + assert outsource("test5") == snapshot() + assert outsource("test6") == snapshot() + if False: + assert outsource("test4") == external("hash:a*.txt") +""", + } + ) + .run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + ".inline-snapshot/external/a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt": "test5", + ".inline-snapshot/external/a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt": "test4", + ".inline-snapshot/external/ed0cb90bdfa4f93981a7d03cff99213a86aa96a6cbcf89ec5e8889871f088727.txt": "test6", + "tests/test_a.py": """\ - with raises( - snapshot( - "StorageLookupError: hash collision files=['a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt', 'a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt']" - ) - ): - assert outsource("test4") == external("hash:a*.txt") +from inline_snapshot import outsource,snapshot,external + +def test_a(): + assert outsource("test4") == snapshot(external("hash:a4e624d686e0*.txt")) + assert outsource("test5") == snapshot(external("hash:a140c0c1eda2*.txt")) + assert outsource("test6") == snapshot(external("hash:ed0cb90bdfa4*.txt")) + if False: + assert outsource("test4") == external("hash:a*.txt") +""", + } + ), + ) + .replace("False", "True") + ) + + with raises( + snapshot( + "StorageLookupError: hash collision files=['a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt', 'a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt']" + ) + ): + e.run_inline() + +def test_hash_not_found(): + with snapshot_env(): with raises( snapshot( "StorageLookupError: hash 'bbbbb*.txt' is not found in the HashStorage" @@ -96,15 +133,21 @@ def test_diskstorage(): assert outsource("test4") == external("hash:bbbbb*.txt") -def test_update_legacy_external_names(project): +def test_update_legacy_external_names(): ( Example( - """\ + { + "pyproject.toml": """\ +[tool.inline-snapshot] +default-storage="hash" +""", + "tests/test_something.py": """\ from inline_snapshot import outsource,snapshot def test_something(): assert outsource("foo") == snapshot() -""" +""", + } ) .run_pytest( ["--inline-snapshot=create"], @@ -179,13 +222,10 @@ def test_pytest_compare_external_bytes(project): from inline_snapshot import external def test_a(): - assert outsource(b"test") == snapshot( - external("hash:9f86d081884c*.bin") - ) + s=snapshot() + assert outsource(b"test") == s - assert outsource(b"test2") == snapshot( - external("hash:9f86d081884c*.bin") - ) + assert outsource(b"test2") == s """ ) @@ -194,7 +234,7 @@ def test_a(): assert result.errorLines() == ( snapshot( """\ -> assert outsource(b"test2") == snapshot( +> assert outsource(b"test2") == s E AssertionError: assert b'test2' == b'test' E \n\ E Use -v to get more diff @@ -227,12 +267,19 @@ def test_a(): from inline_snapshot import external def test_a(): - assert outsource("test") == snapshot(external("hash:9f86d081884c*.txt")) + assert outsource("test") == snapshot(external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt")) """ ) def test_pytest_trim_external(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """\ def test_a(): @@ -288,6 +335,13 @@ def test_a(): def test_pytest_new_external(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """\ def test_a(): @@ -296,9 +350,7 @@ def test_a(): ) project.run() - assert project.storage() == snapshot( - ["9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.txt"] - ) + assert project.storage() == snapshot([]) project.run("--inline-snapshot=create") @@ -308,6 +360,12 @@ def test_a(): def test_pytest_config_hash_length(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) project.setup( """\ def test_a(): @@ -415,7 +473,7 @@ def test_something(): from inline_snapshot import external def test_something(): from inline_snapshot import outsource,snapshot - assert outsource("test") == snapshot(external("hash:9f86d081884c*.txt")) + assert outsource("test") == snapshot(external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt")) test_something() \ """ @@ -465,6 +523,13 @@ def test_ensure_imports_with_comment(tmp_path): def test_new_externals(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """ @@ -479,10 +544,7 @@ def test_something(): project.run("--inline-snapshot=create") assert project.storage() == snapshot( - [ - "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt", - "8dc140e6fe831481a2005ae152ffe32a9974aa92a260dfbac780d6a87154bb0b.txt", - ] + ["2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt"] ) assert project.source == snapshot( @@ -504,10 +566,7 @@ def test_something(): project.run() assert project.storage() == snapshot( - [ - "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt", - "8dc140e6fe831481a2005ae152ffe32a9974aa92a260dfbac780d6a87154bb0b.txt", - ] + ["2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt"] ) @@ -540,14 +599,14 @@ def test_something(): ["--inline-snapshot=create"], changed_files=snapshot( { - ".inline-snapshot/external/2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt": "foo", + "tests/__inline_snapshot__/test_something/test_something/f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt": "foo", "tests/__inline_snapshot__/test_something/test_something/e3e70682-c209-4cac-a29f-6fbed82c07cd.txt": "foo", "tests/test_something.py": """\ from inline_snapshot import external, snapshot,outsource def test_something(): - assert outsource("foo") == snapshot(external("hash:2c26b46b68ff*.txt")) + assert outsource("foo") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) assert "foo" == external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt") """, } diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 2a9cb981..aa11b2b9 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -814,6 +814,7 @@ def test_storage_dir_config(project, tmp_path, storage_dir): "pyproject.toml": f"""\ [tool.inline-snapshot] storage-dir = {str(storage_dir)!r} +default-storage="hash" """, "tests/test_a.py": """\ from inline_snapshot import outsource, snapshot @@ -871,6 +872,7 @@ def test_find_pyproject_in_parent_directories(): "pyproject.toml": """\ [tool.inline-snapshot] hash-length=2 +default-storage="hash" """, "project/pytest.ini": "", "project/test_something.py": """\ @@ -904,6 +906,7 @@ def test_find_pyproject_in_workspace_project(): "sub_project/pyproject.toml": """\ [tool.inline-snapshot] hash-length=2 +default-storage="hash" """, "pyproject.toml": "[tool.pytest.ini_options]", "sub_project/test_something.py": """\ From 35e24db0c025a5bb0620e61131873c99db6ee307 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 30 Dec 2025 18:54:05 +0100 Subject: [PATCH 21/72] refactor: use customize to format strings --- src/inline_snapshot/_adapter_context.py | 3 +++ src/inline_snapshot/_customize/__init__.py | 14 +++++++++++ src/inline_snapshot/_new_adapter.py | 13 ++++++----- .../_snapshot/collection_value.py | 2 +- src/inline_snapshot/_snapshot/dict_value.py | 6 +++-- src/inline_snapshot/_source_file.py | 23 ++++++++----------- src/inline_snapshot/_utils.py | 19 +-------------- 7 files changed, 39 insertions(+), 41 deletions(-) diff --git a/src/inline_snapshot/_adapter_context.py b/src/inline_snapshot/_adapter_context.py index 9de8cc55..017edb83 100644 --- a/src/inline_snapshot/_adapter_context.py +++ b/src/inline_snapshot/_adapter_context.py @@ -24,3 +24,6 @@ def eval(self, node): self.frame.globals, self.frame.locals, ) + + def _value_to_code(self, value): + return self.file._value_to_code(value, self) diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 5eb7af2e..16218274 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -42,6 +42,7 @@ from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import clone +from inline_snapshot._utils import triple_quote from ._custom import Custom from ._custom import CustomizeHandler @@ -410,6 +411,19 @@ def standard_handler(value, builder: Builder): return builder.create_dict(value) +@customize +def string_handler(value, builder: Builder): + if isinstance(value, str) and ( + ("\n" in value and value[-1] != "\n") or value.count("\n") > 1 + ): + + triple_quoted_string = triple_quote(value) + + assert ast.literal_eval(triple_quoted_string) == value + + return builder.create_value(value, triple_quoted_string) + + @customize def counter_handler(value, builder: Builder): if isinstance(value, Counter): diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 4817510c..679704dc 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -6,6 +6,7 @@ from typing import Generator from typing import Sequence +from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._align import add_x from inline_snapshot._align import align from inline_snapshot._change import CallArg @@ -137,7 +138,7 @@ def reeval_CustomDict(old_value, value): class NewAdapter: - def __init__(self, context): + def __init__(self, context: AdapterContext): self.context = context def compare( @@ -272,7 +273,7 @@ def compare_CustomSequence( old_position += 1 elif c == "i": new_value_element = next(new) - new_code = self.context.file._value_to_code(new_value_element) + new_code = self.context._value_to_code(new_value_element) result.append(new_value_element) to_insert[old_position].append((new_code, new_value_element)) elif c == "d": @@ -356,8 +357,8 @@ def compare_CustomDict( if to_insert: new_code = [ ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), + self.context._value_to_code(k), + self.context._value_to_code(v), ) for k, v in to_insert ] @@ -376,8 +377,8 @@ def compare_CustomDict( if to_insert: new_code = [ ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), + self.context._value_to_code(k), + self.context._value_to_code(v), ) for k, v in to_insert ] diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 6ee7a03d..6cd19a89 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -86,6 +86,6 @@ def _get_changes(self) -> Iterator[Change]: file=self._file, node=self._ast_node, position=len(self._old_value.value), - new_code=[self._file._value_to_code(v) for v in new_values], + new_code=[self._context._value_to_code(v) for v in new_values], new_values=new_values, ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index f4bbe6f3..d214b8c2 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -63,7 +63,7 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: for k, v in self._new_value.value.items(): if not isinstance(v, UndecidedValue): new_code = yield from v._new_code() # type:ignore - values.append(f"{self._file._value_to_code(k)}: {new_code}") + values.append(f"{self._context._value_to_code(k)}: {new_code}") return "{" + ", ".join(values) + "}" @@ -95,7 +95,9 @@ def _get_changes(self) -> Iterator[Change]: to_insert.append((key, new_code)) if to_insert: - new_code = [(self._file._value_to_code(k.eval()), v) for k, v in to_insert] + new_code = [ + (self._context._value_to_code(k.eval()), v) for k, v in to_insert + ] yield DictInsert( "create", self._file, diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index dd343525..b3a3964e 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -1,14 +1,14 @@ -import ast import tokenize from pathlib import Path from executing import Source +from inline_snapshot._code_repr import code_repr from inline_snapshot._format import enforce_formatting from inline_snapshot._format import format_code +from inline_snapshot._generator_utils import only_value from inline_snapshot._utils import normalize from inline_snapshot._utils import simple_token -from inline_snapshot._utils import value_to_token from ._utils import ignore_tokens @@ -33,20 +33,15 @@ def asttokens(self): return self._source.asttokens() def _token_to_code(self, tokens): - if len(tokens) == 1 and tokens[0].type == 3: - try: - if ast.literal_eval(tokens[0].string) == "": - # https://github.com/15r10nk/inline-snapshot/issues/281 - # https://github.com/15r10nk/inline-snapshot/issues/258 - # this would otherwise cause a triple-quoted-string because black would format it as a docstring at the beginning of the code - return '""' - except: # pragma: no cover - pass return self._format(tokenize.untokenize(tokens)).strip() - def _value_to_code(self, value): - # TODO remove this - return self._token_to_code(value_to_token(value)) + def _value_to_code(self, value, context): + from inline_snapshot._customize._custom import Custom + + if isinstance(value, Custom): + return self._format(only_value(value.repr(context))).strip() + else: + return self._format(code_repr(value)).strip() def _token_of_node(self, node): diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index cf680eb7..ead05a81 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -161,25 +161,8 @@ def value_to_token(value): def map_strings(code_repr): input = io.StringIO(code_repr) - def map_string(tok): - """Convert strings with newlines in triple quoted strings.""" - if tok.type == token.STRING: - s = ast.literal_eval(tok.string) - if isinstance(s, str) and ( - ("\n" in s and s[-1] != "\n") or s.count("\n") > 1 - ): - # unparse creates a triple quoted string here, - # because it thinks that the string should be a docstring - triple_quoted_string = triple_quote(s) - - assert ast.literal_eval(triple_quoted_string) == s - - return simple_token(tok.type, triple_quoted_string) - - return simple_token(tok.type, tok.string) - return [ - map_string(t) + simple_token(t.type, t.string) for t in tokenize.generate_tokens(input.readline) if t.type not in ignore_tokens ] From 1d76d8b4db276402600c227a937cc6372faf3729 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 30 Dec 2025 21:54:58 +0100 Subject: [PATCH 22/72] refactor: simplify code generation --- README.md | 2 +- docs/external/outsource.md | 6 ++- src/inline_snapshot/_code_repr.py | 38 +++++++++++------- src/inline_snapshot/_customize/__init__.py | 6 +-- .../_external/_external_base.py | 16 +++++++- .../_external/_storage/_hash.py | 6 ++- .../_external/_storage/_protocol.py | 4 +- .../_external/_storage/_uuid.py | 2 +- src/inline_snapshot/_inline_snapshot.py | 2 + src/inline_snapshot/_new_adapter.py | 40 ++++++++----------- .../_snapshot/collection_value.py | 19 +++------ src/inline_snapshot/_snapshot/dict_value.py | 5 ++- src/inline_snapshot/_snapshot/eq_value.py | 21 +++++----- .../_snapshot/generic_value.py | 7 +++- .../_snapshot/min_max_value.py | 21 +++------- .../_snapshot/undecided_value.py | 36 ++--------------- src/inline_snapshot/_source_file.py | 26 +++++++++--- src/inline_snapshot/_utils.py | 6 ++- tests/external/test_external.py | 16 ++++---- tests/test_formatting.py | 10 ++--- 20 files changed, 145 insertions(+), 144 deletions(-) diff --git a/README.md b/README.md index 817c0b86..54c09b98 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ def test_something(): The following examples show how you can use inline-snapshot in your tests. Take a look at the [documentation](https://15r10nk.github.io/inline-snapshot/latest) if you want to know more. - + ``` python from inline_snapshot import snapshot, outsource, external diff --git a/docs/external/outsource.md b/docs/external/outsource.md index cde6bb58..605528f6 100644 --- a/docs/external/outsource.md +++ b/docs/external/outsource.md @@ -26,7 +26,7 @@ def test_captcha(): inline-snapshot always generates an external object in this case. -``` python hl_lines="3 4 20 21 22 23 24 25 26" +``` python hl_lines="3 4 20 21 22 23 24 25 26 27 28" from inline_snapshot import outsource, register_format_alias, snapshot from inline_snapshot import external @@ -50,7 +50,9 @@ def test_captcha(): { "size": "200x100", "difficulty": 8, - "picture": external("hash:0da2cc316111*.png"), + "picture": external( + "uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.png" + ), } ) ``` diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index ed547d8d..4ed0fb83 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,11 +1,18 @@ +from __future__ import annotations + +from contextlib import contextmanager from enum import Enum from enum import Flag -from functools import partial from functools import singledispatch +from typing import TYPE_CHECKING from unittest import mock from inline_snapshot._generator_utils import only_value +if TYPE_CHECKING: + from inline_snapshot._adapter_context import AdapterContext + + real_repr = repr @@ -35,8 +42,9 @@ def __eq__(self, other): if type(other) is not self._type: return False - other_repr = value_code_repr(other) - return other_repr == self._str_repr or other_repr == repr(self) + with mock_repr(None): + other_repr = value_code_repr(other) + return other_repr == self._str_repr or other_repr == real_repr(self) @singledispatch @@ -63,24 +71,26 @@ def _(obj: MyCustomClass): code_repr_dispatch.register(f) -def code_repr(obj, context=None): +def code_repr(obj): + with mock_repr(None): + return repr(obj) - new_repr = partial(mocked_code_repr, context=context) - with mock.patch("builtins.repr", new_repr): - return new_repr(obj) +@contextmanager +def mock_repr(context: AdapterContext): + def new_repr(obj): + from inline_snapshot._customize import Builder + return only_value( + Builder(_snapshot_context=context)._get_handler(obj).repr(context) + ) -def mocked_code_repr(obj, context): - from inline_snapshot._customize import Builder - - return only_value( - Builder(_snapshot_context=context)._get_handler(obj).repr(context) - ) + with mock.patch("builtins.repr", new_repr): + yield def value_code_repr(obj): - # TODO: check the called functions + assert repr is not real_repr, "@mock_repr is missing" if not type(obj) == type(obj): # pragma: no cover # this was caused by https://github.com/samuelcolvin/dirty-equals/issues/104 diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 16218274..1d74376d 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -320,7 +320,7 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: location = ExternalLocation( storage=storage_name, stem="", - suffix=format.suffix, + suffix=self.format or format.suffix, filename=Path(context.file.filename), qualname=context.qualname, ) @@ -687,9 +687,7 @@ def outsourced_handler(value, builder: Builder): from inline_snapshot._external._outsource import Outsourced if isinstance(value, Outsourced): - return builder.create_value(value, repr(value)).with_import( - "inline_snapshot", "external" - ) + return builder.create_external(value, value.suffix, value.storage) @dataclass diff --git a/src/inline_snapshot/_external/_external_base.py b/src/inline_snapshot/_external/_external_base.py index 404e4675..ed44487e 100644 --- a/src/inline_snapshot/_external/_external_base.py +++ b/src/inline_snapshot/_external/_external_base.py @@ -6,9 +6,11 @@ from inline_snapshot._exceptions import UsageError from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._external._outsource import Outsourced +from inline_snapshot._external._storage._protocol import StorageLookupError from inline_snapshot._global_state import state from .._snapshot.generic_value import GenericValue +from ._external_location import ExternalLocation from ._external_location import Location @@ -74,7 +76,19 @@ def __eq__(self, other): return True return False - value = self._load_value() + try: + value = self._load_value() + except StorageLookupError as error: + if not error.files and state().update_flags.fix: + self._original_location = ExternalLocation.from_name("") + self._assign(other) + state().incorrect_values += 1 + if state().update_flags.fix: + return True + return False + else: + raise + result = value == other if not result and first_comparison: diff --git a/src/inline_snapshot/_external/_storage/_hash.py b/src/inline_snapshot/_external/_storage/_hash.py index 4876c5b7..f54b4f38 100644 --- a/src/inline_snapshot/_external/_storage/_hash.py +++ b/src/inline_snapshot/_external/_storage/_hash.py @@ -93,11 +93,13 @@ def _lookup_path(self, name) -> Path: if len(files) > 1: raise StorageLookupError( - f"hash collision files={sorted(f.name for f in files)}" + f"hash collision files={sorted(f.name for f in files)}", files=files ) if not files: - raise StorageLookupError(f"hash {name!r} is not found in the HashStorage") + raise StorageLookupError( + f"hash {name!r} is not found in the HashStorage", files=[] + ) return files[0] diff --git a/src/inline_snapshot/_external/_storage/_protocol.py b/src/inline_snapshot/_external/_storage/_protocol.py index 1e174ab3..45168e5c 100644 --- a/src/inline_snapshot/_external/_storage/_protocol.py +++ b/src/inline_snapshot/_external/_storage/_protocol.py @@ -8,7 +8,9 @@ class StorageLookupError(Exception): - pass + def __init__(self, msg, files): + super().__init__(msg) + self.files = files class StorageProtocol: diff --git a/src/inline_snapshot/_external/_storage/_uuid.py b/src/inline_snapshot/_external/_storage/_uuid.py index de40ad00..07cd3e26 100644 --- a/src/inline_snapshot/_external/_storage/_uuid.py +++ b/src/inline_snapshot/_external/_storage/_uuid.py @@ -48,7 +48,7 @@ def _lookup_path(self, location: ExternalLocation): if location.path in external_files(): return external_files()[location.path] else: - raise StorageLookupError(location) + raise StorageLookupError(location, files=[]) def store(self, location: ExternalLocation, file_path: Path): snapshot_path = self._get_path(location) diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index dc829d91..2d05d1db 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -141,6 +141,8 @@ def _changes(self) -> Iterator[ChangeBase]: new_code = yield from self._value._new_code() + new_code = self._context.file.format_expression(new_code) + yield CallArg( flag="create", file=self._value._file, diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 679704dc..4e3afbfc 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -28,7 +28,6 @@ from inline_snapshot._customize import CustomValue from inline_snapshot._exceptions import UsageError from inline_snapshot._generator_utils import only_value -from inline_snapshot._utils import map_strings from inline_snapshot.syntax_warnings import InlineSnapshotInfo from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning @@ -181,10 +180,9 @@ def compare_CustomValue( assert isinstance(old_node, (ast.expr, type(None))), old_node if old_node is None: - new_token = [] + new_code = "" else: new_code = yield from new_value.repr(self.context) - new_token = map_strings(new_code) if ( isinstance(old_node, ast.JoinedStr) @@ -207,21 +205,19 @@ def compare_CustomValue( flag = "create" else: flag = "fix" - elif ( - old_node is not None - and not isinstance(old_value, CustomUnmanaged) - and self.context.file._token_of_node(old_node) != new_token - ): + elif not isinstance( + old_value, CustomUnmanaged + ) and self.context.file.code_changed(old_node, new_code): flag = "update" else: # equal and equal repr return old_value - new_code = self.context.file._token_to_code(new_token) + new_code = self.context.file.format_expression(new_code) yield Replace( node=old_node, - file=self.context.file._source, + file=self.context.file, new_code=new_code, flag=flag, old_value=old_value.eval(), @@ -235,7 +231,7 @@ def needed_imports(value: Custom): return imports if imports := needed_imports(new_value): - yield RequiredImports(flag, self.context.file._source, imports) + yield RequiredImports(flag, self.context.file, imports) return new_value @@ -280,7 +276,7 @@ def compare_CustomSequence( old_value_element, old_node_element = next(old) yield Delete( "fix", - self.context.file._source, + self.context.file, old_node_element, old_value_element, ) @@ -291,7 +287,7 @@ def compare_CustomSequence( for position, code_values in to_insert.items(): yield ListInsert( "fix", - self.context.file._source, + self.context.file, old_node, position, *zip(*code_values), # type:ignore @@ -332,9 +328,7 @@ def compare_CustomDict( ): if key2 not in new_value.value: # delete entries - yield Delete( - "fix", self.context.file._source, node2, old_value.value[key2] - ) + yield Delete("fix", self.context.file, node2, old_value.value[key2]) to_insert = [] insert_pos = 0 @@ -364,7 +358,7 @@ def compare_CustomDict( ] yield DictInsert( "fix", - self.context.file._source, + self.context.file, old_node, insert_pos, new_code, @@ -384,7 +378,7 @@ def compare_CustomDict( ] yield DictInsert( "fix", - self.context.file._source, + self.context.file, old_node, len(old_value.value), new_code, @@ -442,7 +436,7 @@ def intercept(stream): for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: yield Delete( flag, - self.context.file._source, + self.context.file, node, old_value.argument(arg_pos), ) @@ -452,7 +446,7 @@ def intercept(stream): new_code = yield from value.repr(self.context) yield CallArg( flag=flag, - file=self.context.file._source, + file=self.context.file, node=old_node, arg_pos=insert_pos, arg_name=None, @@ -478,7 +472,7 @@ def intercept(stream): if old_value.argument(kw_arg) == new_value.argument(kw_arg) else flag ), - self.context.file._source, + self.context.file, kw_value, old_value.argument(kw_arg), ) @@ -506,7 +500,7 @@ def intercept(stream): new_code = yield from value.repr(self.context) yield CallArg( flag=flag, - file=self.context.file._source, + file=self.context.file, node=old_node, arg_pos=insert_pos, arg_name=key, @@ -524,7 +518,7 @@ def intercept(stream): yield CallArg( flag=flag, - file=self.context.file._source, + file=self.context.file, node=old_node, arg_pos=insert_pos, arg_name=key, diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 6cd19a89..49c10cbc 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -5,14 +5,11 @@ from inline_snapshot._customize import CustomList from inline_snapshot._customize import CustomUndefined -from .._change import Change from .._change import ChangeBase from .._change import Delete from .._change import ListInsert from .._change import Replace from .._global_state import state -from .._utils import map_strings -from .._utils import value_to_token from .generic_value import GenericValue from .generic_value import ignore_old_value @@ -25,10 +22,10 @@ def __contains__(self, item): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = CustomList([self.get_builder()._get_handler(item)]) + self._new_value = CustomList([self.to_custom(item)]) else: if item not in self._new_value.eval(): - self._new_value.value.append(self.get_builder()._get_handler(item)) + self._new_value.value.append(self.to_custom(item)) if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True @@ -37,9 +34,9 @@ def __contains__(self, item): def _new_code(self) -> Generator[ChangeBase, None, str]: code = yield from self._new_value.repr(self._context) - return self._file._token_to_code(map_strings(code)) + return code - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: assert isinstance(self._old_value, CustomList), self._old_value assert isinstance(self._new_value, CustomList), self._new_value @@ -60,13 +57,9 @@ def _get_changes(self) -> Iterator[Change]: continue # check for update - new_token = value_to_token(old_value.eval()) + new_code = yield from self.to_custom(old_value.eval()).repr(self._context) - if ( - old_node is not None - and self._file._token_of_node(old_node) != new_token - ): - new_code = self._file._token_to_code(new_token) + if self._file.code_changed(old_node, new_code): yield Replace( node=old_node, diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index d214b8c2..716ef0a6 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -25,7 +25,7 @@ def __getitem__(self, index): if isinstance(self._new_value, CustomUndefined): self._new_value = CustomDict({}) - index = self.get_builder()._get_handler(index) + index = self.to_custom(index) if index not in self._new_value.value: if isinstance(self._old_value, CustomUndefined): @@ -92,7 +92,8 @@ def _get_changes(self) -> Iterator[Change]: ): # add new values new_code = yield from new_value_element._new_code() # type:ignore - to_insert.append((key, new_code)) + + to_insert.append((key, self._file.format_expression(new_code))) if to_insert: new_code = [ diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 127dd8b4..b14df492 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -2,9 +2,10 @@ from typing import Iterator from typing import List +from inline_snapshot._code_repr import mock_repr from inline_snapshot._customize import CustomUndefined +from inline_snapshot._generator_utils import split_gen from inline_snapshot._new_adapter import NewAdapter -from inline_snapshot._utils import map_strings from .._change import Change from .._change import ChangeBase @@ -18,7 +19,8 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - custom_other = self.get_builder(_build_new_value=True)._get_handler(other) + with mock_repr(self._context): + custom_other = self.get_builder(_build_new_value=True)._get_handler(other) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 @@ -27,13 +29,12 @@ def __eq__(self, other): self._changes = [] adapter = NewAdapter(self._context) - it = iter(adapter.compare(self._old_value, self._ast_node, custom_other)) - while True: - try: - self._changes.append(next(it)) - except StopIteration as ex: - self._new_value = ex.value - break + + result = split_gen( + adapter.compare(self._old_value, self._ast_node, custom_other) + ) + self._changes = result.list + self._new_value = result.value return self._return( self._old_value.eval() == other, @@ -42,7 +43,7 @@ def __eq__(self, other): def _new_code(self) -> Generator[ChangeBase, None, str]: code = yield from self._new_value.repr(self._context) - return self._file._token_to_code(map_strings(code)) + return code def _get_changes(self) -> Iterator[Change]: return iter(getattr(self, "_changes", [])) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 961dee47..66f85584 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -3,6 +3,7 @@ from typing import Iterator from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._code_repr import mock_repr from inline_snapshot._customize import Builder from inline_snapshot._customize import Custom from inline_snapshot._customize import CustomUndefined @@ -48,12 +49,16 @@ def _return(self, result, new_result=True): def _file(self): return self._context.file + def to_custom(self, value): + with mock_repr(self._context): + return self.get_builder()._get_handler(value) + def value_to_custom(self, value): if isinstance(value, Custom): return value if self._ast_node is None: - return self.get_builder()._get_handler(value) + return self.to_custom(value) else: from inline_snapshot._snapshot.undecided_value import AstToCustom diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 18486339..c22ec34f 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -3,12 +3,9 @@ from inline_snapshot._customize import CustomUndefined -from .._change import Change from .._change import ChangeBase from .._change import Replace from .._global_state import state -from .._utils import map_strings -from .._utils import value_to_token from .generic_value import GenericValue from .generic_value import ignore_old_value @@ -25,38 +22,32 @@ def _generic_cmp(self, other): state().missing_values += 1 if isinstance(self._new_value, CustomUndefined): - self._new_value = self.get_builder()._get_handler(other) + self._new_value = self.to_custom(other) if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True return self._return(self.cmp(self._old_value.eval(), other)) else: if not self.cmp(self._new_value.eval(), other): - self._new_value = self.get_builder()._get_handler(other) + self._new_value = self.to_custom(other) return self._return(self.cmp(self._visible_value().eval(), other)) def _new_code(self) -> Generator[ChangeBase, None, str]: code = yield from self._new_value.repr(self._context) - return self._file._token_to_code(map_strings(code)) + return code - def _get_changes(self) -> Iterator[Change]: - # TODO repr() ... - new_token = value_to_token(self._new_value.eval()) + def _get_changes(self) -> Iterator[ChangeBase]: + new_code = yield from self._new_code() if not self.cmp(self._old_value.eval(), self._new_value.eval()): flag = "fix" elif not self.cmp(self._new_value.eval(), self._old_value.eval()): flag = "trim" - elif ( - self._ast_node is not None - and self._file._token_of_node(self._ast_node) != new_token - ): + elif self._file.code_changed(self._ast_node, new_code): flag = "update" else: return - new_code = self._file._token_to_code(new_token) - yield Replace( node=self._ast_node, file=self._file, diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 14635a03..f665a706 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -72,17 +72,13 @@ def convert_Dict(self, value: dict, node: ast.Dict): class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): self._context = context + self._ast_node = ast_node - if not isinstance(old_value, Custom): - if ast_node is not None: - old_value = AstToCustom(context).convert(old_value, ast_node) - else: - old_value = self.get_builder()._get_handler(old_value) + old_value = self.value_to_custom(old_value) assert isinstance(old_value, Custom) self._old_value = old_value self._new_value = CustomUndefined() - self._ast_node = ast_node def _change(self, cls): self.__class__ = cls @@ -93,7 +89,7 @@ def _new_code(self): def _get_changes(self) -> Iterator[ChangeBase]: assert isinstance(self._new_value, CustomUndefined) - new_value = self.get_builder()._get_handler(self._old_value.eval()) + new_value = self.to_custom(self._old_value.eval()) adapter = NewAdapter(self._context) @@ -101,32 +97,6 @@ def _get_changes(self) -> Iterator[ChangeBase]: assert change.flag == "update", change yield change - # def handle(node, obj): - - # adapter = get_adapter_type(obj) - # if adapter is not None and hasattr(adapter, "items"): - # for item in adapter.items(obj, node): - # yield from handle(item.node, item.value) - # return - - # if not isinstance(obj, CustomUnmanaged) and node is not None: - # new_token = value_to_token(obj.eval()) - # if self._file._token_of_node(node) != new_token: - # new_code = self._file._token_to_code(new_token) - - # yield Replace( - # node=self._ast_node, - # file=self._file, - # new_code=new_code, - # flag="update", - # old_value=self._old_value, - # new_value=self._old_value, - # ) - - # yield from handle(self._ast_node, self._old_value) - - # functions which determine the type - def __eq__(self, other): if compare_only(): return False diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index b3a3964e..95ee295f 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -11,6 +11,7 @@ from inline_snapshot._utils import simple_token from ._utils import ignore_tokens +from ._utils import map_strings class SourceFile: @@ -23,25 +24,38 @@ def __init__(self, source: Source): def filename(self) -> str: return self._source.filename - def _format(self, text): + def _format(self, code): if self._source is None or enforce_formatting(): - return text + return code else: - return format_code(text, Path(self._source.filename)) + return format_code(code, Path(self._source.filename)) + + def format_expression(self, code): + return self._format(code).strip() def asttokens(self): return self._source.asttokens() def _token_to_code(self, tokens): - return self._format(tokenize.untokenize(tokens)).strip() + return self.format_expression(tokenize.untokenize(tokens)) def _value_to_code(self, value, context): from inline_snapshot._customize._custom import Custom if isinstance(value, Custom): - return self._format(only_value(value.repr(context))).strip() + return self.format_expression(only_value(value.repr(context))) else: - return self._format(code_repr(value)).strip() + # TODO assert False + return self.format_expression(code_repr(value)) + + def code_changed(self, old_node, new_code): + + if old_node is None: + return False + + new_token = map_strings(new_code) + + return self._token_of_node(old_node) != new_token def _token_of_node(self, node): diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index ead05a81..d4188524 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -8,7 +8,6 @@ from inline_snapshot._exceptions import UsageError -from ._code_repr import code_repr from ._code_repr import value_code_repr @@ -155,7 +154,10 @@ def __eq__(self, other): def value_to_token(value): - return map_strings(code_repr(value)) + from inline_snapshot._customize._custom import Custom + + assert isinstance(value, Custom) + return map_strings(value.repr()) def map_strings(code_repr): diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 5d71d5d0..6a03eff0 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -124,12 +124,12 @@ def test_a(): def test_hash_not_found(): - with snapshot_env(): - with raises( - snapshot( - "StorageLookupError: hash 'bbbbb*.txt' is not found in the HashStorage" - ) - ): + with raises( + snapshot( + "StorageLookupError: hash 'bbbbb*.txt' is not found in the HashStorage" + ) + ): + with snapshot_env(): assert outsource("test4") == external("hash:bbbbb*.txt") @@ -243,7 +243,7 @@ def test_a(): if sys.version_info >= (3, 11) else snapshot( """\ -> assert outsource(b"test2") == snapshot( +> assert outsource(b"test2") == s E AssertionError """ ) @@ -628,7 +628,7 @@ def test_something(): error=snapshot( """\ > assert "foo" == external("hash:aaaaaaaaaaaa*.txt") -> raise StorageLookupError(f"hash {name!r} is not found in the HashStorage") +> raise StorageLookupError( E inline_snapshot._external._storage._protocol.StorageLookupError: hash 'aaaaaaaaaaaa*.txt' is not found in the HashStorage """ ), diff --git a/tests/test_formatting.py b/tests/test_formatting.py index f918f1e6..706ce274 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -34,7 +34,7 @@ def test_something(): def test_something(): assert 1==snapshot(1) assert 1==snapshot(1) - assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) + assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]) """ } ), @@ -60,8 +60,8 @@ def test_something(): | + assert 1==snapshot(1) | | assert 1==snapshot(2) | | - assert list(range(20)) == snapshot() | -| + assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 | -| ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) | +| + assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, | +| 11, 12, 13, 14, 15, 16, 17, 18, 19]) | +------------------------------------------------------------------------------+ These changes will be applied, because you used create @@ -74,8 +74,8 @@ def test_something(): | assert 1==snapshot(1) | | - assert 1==snapshot(2) | | + assert 1==snapshot(1) | -| assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 | -| ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) | +| assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, | +| 11, 12, 13, 14, 15, 16, 17, 18, 19]) | +------------------------------------------------------------------------------+ These changes will be applied, because you used fix From cb918b0f921e0a69c03a274f973ae2ee77a6116b Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 31 Dec 2025 00:34:07 +0100 Subject: [PATCH 23/72] refactor: removed _value_to_code --- src/inline_snapshot/_adapter_context.py | 3 -- src/inline_snapshot/_new_adapter.py | 38 ++++++++++++------- .../_snapshot/collection_value.py | 15 +++++--- src/inline_snapshot/_snapshot/dict_value.py | 20 ++++++---- src/inline_snapshot/_source_file.py | 17 +-------- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/inline_snapshot/_adapter_context.py b/src/inline_snapshot/_adapter_context.py index 017edb83..9de8cc55 100644 --- a/src/inline_snapshot/_adapter_context.py +++ b/src/inline_snapshot/_adapter_context.py @@ -24,6 +24,3 @@ def eval(self, node): self.frame.globals, self.frame.locals, ) - - def _value_to_code(self, value): - return self.file._value_to_code(value, self) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 4e3afbfc..b12d63d0 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -269,9 +269,11 @@ def compare_CustomSequence( old_position += 1 elif c == "i": new_value_element = next(new) - new_code = self.context._value_to_code(new_value_element) + new_code = yield from new_value_element.repr(self.context) result.append(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) + to_insert[old_position].append( + (self.context.file.format_expression(new_code), new_value_element) + ) elif c == "d": old_value_element, old_node_element = next(old) yield Delete( @@ -349,13 +351,17 @@ def compare_CustomDict( ) if to_insert: - new_code = [ - ( - self.context._value_to_code(k), - self.context._value_to_code(v), + new_code = [] + for k, v in to_insert: + new_code_key = yield from k.repr(self.context) + new_code_value = yield from v.repr(self.context) + new_code.append( + ( + self.context.file.format_expression(new_code_key), + self.context.file.format_expression(new_code_value), + ) ) - for k, v in to_insert - ] + yield DictInsert( "fix", self.context.file, @@ -369,13 +375,17 @@ def compare_CustomDict( insert_pos += 1 if to_insert: - new_code = [ - ( - self.context._value_to_code(k), - self.context._value_to_code(v), + new_code = [] + for k, v in to_insert: + new_key = yield from k.repr(self.context) + new_value = yield from v.repr(self.context) + new_code.append( + ( + self.context.file.format_expression(new_key), + self.context.file.format_expression(new_value), + ) ) - for k, v in to_insert - ] + yield DictInsert( "fix", self.context.file, diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 49c10cbc..8166f2cc 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -70,15 +70,20 @@ def _get_changes(self) -> Iterator[ChangeBase]: new_value=old_value, ) - new_values = [ - v.eval() for v in self._new_value.value if v not in self._old_value.value - ] - if new_values: + new_codes = [] + new_values = [] + for v in self._new_value.value: + if v not in self._old_value.value: + new_code = yield from v.repr(self._context) + new_codes.append(self._file.format_expression(new_code)) + new_values.append(v.eval()) + + if new_codes: yield ListInsert( flag="fix", file=self._file, node=self._ast_node, position=len(self._old_value.value), - new_code=[self._context._value_to_code(v) for v in new_values], + new_code=new_codes, new_values=new_values, ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 716ef0a6..0a8fdbb5 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -63,7 +63,8 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: for k, v in self._new_value.value.items(): if not isinstance(v, UndecidedValue): new_code = yield from v._new_code() # type:ignore - values.append(f"{self._context._value_to_code(k)}: {new_code}") + new_key = yield from k.repr(self._context) + values.append(f"{new_key}: {new_code}") return "{" + ", ".join(values) + "}" @@ -86,24 +87,29 @@ def _get_changes(self) -> Iterator[Change]: yield Delete("trim", self._file, node, self._old_value.value[key]) to_insert = [] + to_insert_values = [] for key, new_value_element in self._new_value.value.items(): if key not in self._old_value.value and not isinstance( new_value_element, UndecidedValue ): # add new values - new_code = yield from new_value_element._new_code() # type:ignore + new_value = yield from new_value_element._new_code() # type:ignore + new_key = yield from key.repr(self._context) - to_insert.append((key, self._file.format_expression(new_code))) + to_insert.append( + ( + self._file.format_expression(new_key), + self._file.format_expression(new_value), + ) + ) + to_insert_values.append((key, new_value_element)) if to_insert: - new_code = [ - (self._context._value_to_code(k.eval()), v) for k, v in to_insert - ] yield DictInsert( "create", self._file, self._ast_node, len(self._old_value.value), - new_code, to_insert, + to_insert_values, ) diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index 95ee295f..ab86551d 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -1,12 +1,9 @@ -import tokenize from pathlib import Path from executing import Source -from inline_snapshot._code_repr import code_repr from inline_snapshot._format import enforce_formatting from inline_snapshot._format import format_code -from inline_snapshot._generator_utils import only_value from inline_snapshot._utils import normalize from inline_snapshot._utils import simple_token @@ -30,24 +27,12 @@ def _format(self, code): else: return format_code(code, Path(self._source.filename)) - def format_expression(self, code): + def format_expression(self, code: str) -> str: return self._format(code).strip() def asttokens(self): return self._source.asttokens() - def _token_to_code(self, tokens): - return self.format_expression(tokenize.untokenize(tokens)) - - def _value_to_code(self, value, context): - from inline_snapshot._customize._custom import Custom - - if isinstance(value, Custom): - return self.format_expression(only_value(value.repr(context))) - else: - # TODO assert False - return self.format_expression(code_repr(value)) - def code_changed(self, old_node, new_code): if old_node is None: From 1364b7e10cf87ccd1682ab660da20a31888712bb Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 31 Dec 2025 00:52:05 +0100 Subject: [PATCH 24/72] refactor: moved format_expression into code generation --- src/inline_snapshot/_change.py | 18 +++++++++++++++ src/inline_snapshot/_inline_snapshot.py | 2 -- src/inline_snapshot/_new_adapter.py | 21 +++++------------- .../_snapshot/collection_value.py | 2 +- src/inline_snapshot/_snapshot/dict_value.py | 10 ++------- src/inline_snapshot/_source_file.py | 20 ++++++++++++----- src/inline_snapshot/_utils.py | 22 ------------------- 7 files changed, 42 insertions(+), 53 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 41e97c55..1edcaa90 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -137,6 +137,9 @@ class AddArgument(Change): new_code: str new_value: Any + def __post_init__(self): + self.new_code = self.file.format_expression(self.new_code) + @dataclass() class ListInsert(Change): @@ -146,6 +149,9 @@ class ListInsert(Change): new_code: list[str] new_values: list[Any] + def __post_init__(self): + self.new_code = [self.file.format_expression(v) for v in self.new_code] + @dataclass() class DictInsert(Change): @@ -155,6 +161,12 @@ class DictInsert(Change): new_code: list[tuple[str, str]] new_values: list[tuple[Any, Any]] + def __post_init__(self): + self.new_code = [ + (self.file.format_expression(k), self.file.format_expression(v)) + for k, v in self.new_code + ] + @dataclass() class Replace(Change): @@ -169,6 +181,9 @@ def apply(self, recorder: ChangeRecorder): range = self.file.asttokens().get_text_positions(self.node, False) change.replace(range, self.new_code, filename=self.filename) + def __post_init__(self): + self.new_code = self.file.format_expression(self.new_code) + @dataclass() class CallArg(Change): @@ -179,6 +194,9 @@ class CallArg(Change): new_code: str new_value: Any + def __post_init__(self): + self.new_code = self.file.format_expression(self.new_code) + TokenRange = Tuple[Token, Token] diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 2d05d1db..dc829d91 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -141,8 +141,6 @@ def _changes(self) -> Iterator[ChangeBase]: new_code = yield from self._value._new_code() - new_code = self._context.file.format_expression(new_code) - yield CallArg( flag="create", file=self._value._file, diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index b12d63d0..12b7eda0 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -213,8 +213,6 @@ def compare_CustomValue( # equal and equal repr return old_value - new_code = self.context.file.format_expression(new_code) - yield Replace( node=old_node, file=self.context.file, @@ -271,9 +269,7 @@ def compare_CustomSequence( new_value_element = next(new) new_code = yield from new_value_element.repr(self.context) result.append(new_value_element) - to_insert[old_position].append( - (self.context.file.format_expression(new_code), new_value_element) - ) + to_insert[old_position].append((new_code, new_value_element)) elif c == "d": old_value_element, old_node_element = next(old) yield Delete( @@ -355,12 +351,7 @@ def compare_CustomDict( for k, v in to_insert: new_code_key = yield from k.repr(self.context) new_code_value = yield from v.repr(self.context) - new_code.append( - ( - self.context.file.format_expression(new_code_key), - self.context.file.format_expression(new_code_value), - ) - ) + new_code.append((new_code_key, new_code_value)) yield DictInsert( "fix", @@ -377,12 +368,12 @@ def compare_CustomDict( if to_insert: new_code = [] for k, v in to_insert: - new_key = yield from k.repr(self.context) - new_value = yield from v.repr(self.context) + new_code_key = yield from k.repr(self.context) + new_code_value = yield from v.repr(self.context) new_code.append( ( - self.context.file.format_expression(new_key), - self.context.file.format_expression(new_value), + new_code_key, + new_code_value, ) ) diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 8166f2cc..704c9dd1 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -75,7 +75,7 @@ def _get_changes(self) -> Iterator[ChangeBase]: for v in self._new_value.value: if v not in self._old_value.value: new_code = yield from v.repr(self._context) - new_codes.append(self._file.format_expression(new_code)) + new_codes.append(new_code) new_values.append(v.eval()) if new_codes: diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 0a8fdbb5..5424b320 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -6,7 +6,6 @@ from inline_snapshot._customize import CustomUndefined from .._adapter_context import AdapterContext -from .._change import Change from .._change import ChangeBase from .._change import Delete from .._change import DictInsert @@ -68,7 +67,7 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: return "{" + ", ".join(values) + "}" - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: assert not isinstance(self._old_value, CustomUndefined) @@ -96,12 +95,7 @@ def _get_changes(self) -> Iterator[Change]: new_value = yield from new_value_element._new_code() # type:ignore new_key = yield from key.repr(self._context) - to_insert.append( - ( - self._file.format_expression(new_key), - self._file.format_expression(new_value), - ) - ) + to_insert.append((new_key, new_value)) to_insert_values.append((key, new_value_element)) if to_insert: diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index ab86551d..44dae906 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -1,3 +1,6 @@ +import io +import token +import tokenize from pathlib import Path from executing import Source @@ -7,8 +10,17 @@ from inline_snapshot._utils import normalize from inline_snapshot._utils import simple_token -from ._utils import ignore_tokens -from ._utils import map_strings +ignore_tokens = (token.NEWLINE, token.ENDMARKER, token.NL) + + +def _token_of_code(code_repr): + input = io.StringIO(code_repr) + + return [ + simple_token(t.type, t.string) + for t in tokenize.generate_tokens(input.readline) + if t.type not in ignore_tokens + ] class SourceFile: @@ -38,9 +50,7 @@ def code_changed(self, old_node, new_code): if old_node is None: return False - new_token = map_strings(new_code) - - return self._token_of_node(old_node) != new_token + return self._token_of_node(old_node) != _token_of_code(new_code) def _token_of_node(self, node): diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index d4188524..492d8312 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -1,8 +1,6 @@ import ast import copy -import io import token -import tokenize from collections import namedtuple from pathlib import Path @@ -76,9 +74,6 @@ def normalize(token_sequence): return skip_trailing_comma(normalize_strings(token_sequence)) -ignore_tokens = (token.NEWLINE, token.ENDMARKER, token.NL) - - # based on ast.unparse def _str_literal_helper(string, *, quote_types): """Helper for writing string literals, minimizing escapes. @@ -153,23 +148,6 @@ def __eq__(self, other): return super().__eq__(other) -def value_to_token(value): - from inline_snapshot._customize._custom import Custom - - assert isinstance(value, Custom) - return map_strings(value.repr()) - - -def map_strings(code_repr): - input = io.StringIO(code_repr) - - return [ - simple_token(t.type, t.string) - for t in tokenize.generate_tokens(input.readline) - if t.type not in ignore_tokens - ] - - def clone(obj): new = copy.deepcopy(obj) if not obj == new: From fa659a7503dac879c91e2a64bf65d1e9cea2cd67 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 1 Jan 2026 16:05:15 +0100 Subject: [PATCH 25/72] test: coverage --- pyproject.toml | 1 + src/inline_snapshot/_customize/__init__.py | 23 +++------- src/inline_snapshot/_customize/_custom.py | 3 +- .../_external/_external_base.py | 4 +- src/inline_snapshot/_external/_outsource.py | 31 +------------ src/inline_snapshot/_partial_call.py | 4 +- src/inline_snapshot/_snapshot/eq_value.py | 4 +- .../_snapshot/generic_value.py | 4 +- tests/adapter/test_dataclass.py | 37 ++++++++++++++++ tests/external/test_external.py | 5 +++ tests/test_customize.py | 43 +++++++++++++++++++ 11 files changed, 101 insertions(+), 58 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46c74a89..d528b5bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ exclude_lines = [ "# pragma: no cover", "if TYPE_CHECKING:", "if is_insider", + ": ..." ] diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 1d74376d..a546306d 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -38,6 +38,7 @@ from inline_snapshot._external._format._protocol import get_format_handler_from_suffix from inline_snapshot._global_state import state from inline_snapshot._partial_call import partial_call +from inline_snapshot._partial_call import partial_check_args from inline_snapshot._sentinels import undefined from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged @@ -123,6 +124,8 @@ def test_myclass(): if f is None: return partial(customize, priority=priority) # type: ignore[return-value] + else: + partial_check_args(f, {"value", "builder", "local_vars", "global_vars"}) from inline_snapshot._global_state import state @@ -135,7 +138,7 @@ class CustomDefault(Custom): value: Custom = field(compare=False) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () + yield from () # pragma: no cover # this should never be called because default values are never converted into code assert False @@ -151,15 +154,12 @@ class CustomUnmanaged(Custom): value: Any def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () - return "'unmanaged'" # pragma: no cover + yield from () # pragma: no cover + return "'unmanaged'" def map(self, f): return f(self.value) - def _needed_imports(self): - yield from () - class CustomUndefined(Custom): def __init__(self): @@ -172,9 +172,6 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: def map(self, f): return f(undefined) - def _needed_imports(self): - yield from () - def unwrap_default(value): if isinstance(value, CustomDefault): @@ -682,14 +679,6 @@ def dirty_equals_handler(value, builder: Builder): return builder.create_call(type(value), args, kwargs) -@customize -def outsourced_handler(value, builder: Builder): - from inline_snapshot._external._outsource import Outsourced - - if isinstance(value, Outsourced): - return builder.create_external(value, value.suffix, value.storage) - - @dataclass class ContextValue: name: str diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 12306910..80dbe65a 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -38,9 +38,8 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: def eval(self): return self.map(lambda a: a) - @abstractmethod def _needed_imports(self): - raise NotImplementedError() + yield from () CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] diff --git a/src/inline_snapshot/_external/_external_base.py b/src/inline_snapshot/_external/_external_base.py index ed44487e..79295c02 100644 --- a/src/inline_snapshot/_external/_external_base.py +++ b/src/inline_snapshot/_external/_external_base.py @@ -83,9 +83,7 @@ def __eq__(self, other): self._original_location = ExternalLocation.from_name("") self._assign(other) state().incorrect_values += 1 - if state().update_flags.fix: - return True - return False + return True else: raise diff --git a/src/inline_snapshot/_external/_outsource.py b/src/inline_snapshot/_external/_outsource.py index 36d4224d..9f9748b0 100644 --- a/src/inline_snapshot/_external/_outsource.py +++ b/src/inline_snapshot/_external/_outsource.py @@ -15,26 +15,6 @@ class Outsourced: data: Any suffix: str | None storage: str | None - # def __init__(self, data: Any, suffix: str | None): - # self.data = data - - # self._format = get_format_handler(data, suffix or "") - # if suffix is None: - # suffix = self._format.suffix - - # self._location = ExternalLocation("hash", "", suffix, None, None) - - # tmp_path = state().new_tmp_path(suffix) - - # self._format.encode(data, tmp_path) - - # storage = state().all_storages["hash"] - - # self._location = storage.new_location( - # self._location, tmp_path # type:ignore - # ) - - # storage.store(self._location, tmp_path) # type: ignore def __eq__(self, other): if isinstance(other, GenericValue): @@ -50,14 +30,6 @@ def __eq__(self, other): return self.data == other - return NotImplemented - - # def __repr__(self) -> str: - # return f'external("{self._location.to_str()}")' - - # def _load_value(self) -> Any: - # return self.data - @customize def outsource_handler(value, builder: Builder): @@ -74,6 +46,7 @@ def outsource(data: Any, suffix: str | None = None, storage: str | None = None) if not state().active: return data - format = get_format_handler(data, suffix or "") + # check if the suffix/datatype is supported + get_format_handler(data, suffix or "") return Outsourced(data, suffix, storage) diff --git a/src/inline_snapshot/_partial_call.py b/src/inline_snapshot/_partial_call.py index dbdc623b..87a1020a 100644 --- a/src/inline_snapshot/_partial_call.py +++ b/src/inline_snapshot/_partial_call.py @@ -3,7 +3,7 @@ from inline_snapshot._exceptions import UsageError -def check_args(func, allowed): +def partial_check_args(func, allowed): sign = inspect.signature(func) for p in sign.parameters.values(): if p.default is not inspect.Parameter.empty: @@ -16,7 +16,7 @@ def check_args(func, allowed): if p.name not in allowed: raise UsageError( - f"`{p.name}` is an unknown parameter. allowed are {allowed}" + f"`{p.name}` is an unknown parameter. allowed are {sorted(allowed)}" ) diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index b14df492..0a1841ee 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -2,7 +2,6 @@ from typing import Iterator from typing import List -from inline_snapshot._code_repr import mock_repr from inline_snapshot._customize import CustomUndefined from inline_snapshot._generator_utils import split_gen from inline_snapshot._new_adapter import NewAdapter @@ -19,8 +18,7 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - with mock_repr(self._context): - custom_other = self.get_builder(_build_new_value=True)._get_handler(other) + custom_other = self.to_custom(other, _build_new_value=True) if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 66f85584..37546b86 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -49,9 +49,9 @@ def _return(self, result, new_result=True): def _file(self): return self._context.file - def to_custom(self, value): + def to_custom(self, value, **args): with mock_repr(self._context): - return self.get_builder()._get_handler(value) + return self.get_builder(**args)._get_handler(value) def value_to_custom(self, value): if isinstance(value, Custom): diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index b54297c0..b2e01d3c 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -121,6 +121,43 @@ def test_something(): ) +def test_dataclass_add_arguments(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + +def test_something(): + for _ in [1,2]: + assert A(a=1,b=5) == snapshot(A(a=1)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + +def test_something(): + for _ in [1,2]: + assert A(a=1,b=5) == snapshot(A(a=1, b=5)) +""" + } + ), + ) + + def test_dataclass_positional_arguments(): Example( """\ diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 6a03eff0..4148ea6c 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -71,6 +71,11 @@ def test_a(): ) +def test_compare_outsource(): + assert outsource("one") == outsource("one") + assert outsource("one") != outsource("two") + + def test_hash_collision(): e = ( Example( diff --git a/tests/test_customize.py b/tests/test_customize.py index e6dfb8db..7da7f97b 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -1,6 +1,8 @@ import pytest from inline_snapshot import snapshot +from inline_snapshot._customize import customize +from inline_snapshot.extra import raises from inline_snapshot.testing import Example @@ -85,3 +87,44 @@ def test(): } ), ) + + +def test_customize_argument_exceptions(): + + with raises( + snapshot("UsageError: `value` has a default value which is not supported") + ): + + @customize + def f(value=5): + pass + + with raises( + snapshot( + "UsageError: `value` is not a positional or keyword parameter, which is not supported" + ) + ): + + @customize + def f(value, /): + pass + + with raises( + snapshot( + "UsageError: `value` is not a positional or keyword parameter, which is not supported" + ) + ): + + @customize + def f(*, value): + pass + + with raises( + snapshot( + "UsageError: `my_own_arg` is an unknown parameter. allowed are ['builder', 'global_vars', 'local_vars', 'value']" + ) + ): + + @customize + def f(my_own_arg): + pass From 88e6f3e8fcf9de3977e0c31f7d9fe5d015a80098 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 4 Jan 2026 21:46:41 +0100 Subject: [PATCH 26/72] feat: use pluggy to manage plugins and extensions --- src/inline_snapshot/__init__.py | 2 +- src/inline_snapshot/_customize/__init__.py | 411 +----------------- src/inline_snapshot/_external/_outsource.py | 10 - src/inline_snapshot/_global_state.py | 28 +- src/inline_snapshot/_snapshot_session.py | 25 ++ src/inline_snapshot/plugin/_context_value.py | 8 + src/inline_snapshot/plugin/_default_plugin.py | 315 ++++++++++++++ src/inline_snapshot/plugin/_spec.py | 85 ++++ src/inline_snapshot/pytest_plugin.py | 23 + src/inline_snapshot/testing/_example.py | 23 + tests/adapter/test_dataclass.py | 37 +- tests/test_customize.py | 63 +-- 12 files changed, 534 insertions(+), 496 deletions(-) create mode 100644 src/inline_snapshot/plugin/_context_value.py create mode 100644 src/inline_snapshot/plugin/_default_plugin.py create mode 100644 src/inline_snapshot/plugin/_spec.py diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 5743d1cd..2c63b0e3 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -6,7 +6,6 @@ from ._customize import Builder from ._customize import Custom from ._customize import CustomizeHandler -from ._customize import customize from ._exceptions import UsageError from ._external._external import external from ._external._external_file import external_file @@ -20,6 +19,7 @@ from ._types import Category from ._types import Snapshot from ._unmanaged import declare_unmanaged +from .plugin._spec import customize from .version import __version__ __all__ = [ diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index a546306d..150eafbf 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -44,95 +44,12 @@ from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import clone from inline_snapshot._utils import triple_quote +from inline_snapshot.plugin._context_value import ContextValue from ._custom import Custom from ._custom import CustomizeHandler -@overload -def customize( - f: None = None, *, priority: int = 0 -) -> Callable[[CustomizeHandler], CustomizeHandler]: ... - - -@overload -def customize(f: CustomizeHandler, *, priority: int = 0) -> CustomizeHandler: ... - - -def customize( - f: CustomizeHandler | None = None, *, priority: int = 0 -) -> CustomizeHandler | Callable[[CustomizeHandler], CustomizeHandler]: - """ - Registers a function as a customization hook inside inline-snapshot. - - Customization hooks allow you to control how objects are represented in snapshot code. - When inline-snapshot generates code for a value, it calls each registered customization - function in reverse order of registration until one returns a Custom object. - - **Important**: Customization handlers should be registered in your `conftest.py` file to ensure - they are loaded before your tests run. - - Args: - f: A customization handler function. See [CustomizeHandler][inline_snapshot._customize.CustomizeHandler] - for the expected signature. - - Returns: - The input function unchanged (for use as a decorator) - - Example: - Basic usage with a custom class: - - - ``` python - from inline_snapshot import customize, snapshot - - - class MyClass: - def __init__(self, arg1, arg2, key=None): - self.arg1 = arg1 - self.arg2 = arg2 - self.key_attr = key - - - @customize - def my_custom_handler(value, builder): - if isinstance(value, MyClass): - # Generate code like: MyClass(arg1, arg2, key=value) - return builder.create_call( - MyClass, [value.arg1, value.arg2], {"key": value.key_attr} - ) - return None # Let other handlers process this value - - - def test_myclass(): - obj = MyClass(42, "hello", key="world") - assert obj == snapshot(MyClass(42, "hello", key="world")) - ``` - - Note: - - **Always register handlers in `conftest.py`** to ensure they're available for all tests - - Handlers are called in **reverse order** of registration (last registered is called first) - - If no handler returns a Custom object, a default representation is used - - Use builder methods (`create_call`, `create_list`, `create_dict`, etc.) to construct representations - - Always return `None` if your handler doesn't apply to the given value type - - The builder automatically handles recursive conversion of nested values - - See Also: - - [Builder][inline_snapshot._customize.Builder]: Available builder methods - - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations - """ - - if f is None: - return partial(customize, priority=priority) # type: ignore[return-value] - else: - partial_check_args(f, {"value", "builder", "local_vars", "global_vars"}) - - from inline_snapshot._global_state import state - - state().custom_functions[priority].append(f) - return f - - @dataclass(frozen=True) class CustomDefault(Custom): value: Custom = field(compare=False) @@ -396,301 +313,6 @@ def with_import(self, module, name, simplify=True): return self -@customize -def standard_handler(value, builder: Builder): - if isinstance(value, list): - return builder.create_list(value) - - if type(value) is tuple: - return builder.create_tuple(value) - - if isinstance(value, dict): - return builder.create_dict(value) - - -@customize -def string_handler(value, builder: Builder): - if isinstance(value, str) and ( - ("\n" in value and value[-1] != "\n") or value.count("\n") > 1 - ): - - triple_quoted_string = triple_quote(value) - - assert ast.literal_eval(triple_quoted_string) == value - - return builder.create_value(value, triple_quoted_string) - - -@customize -def counter_handler(value, builder: Builder): - if isinstance(value, Counter): - return builder.create_call(Counter, [dict(value)]) - - -@customize -def function_handler(value, builder: Builder): - if isinstance(value, FunctionType): - qualname = value.__qualname__ - name = qualname.split(".")[0] - return builder.create_value(value, qualname).with_import(value.__module__, name) - - -@customize -def builtin_function_handler(value, builder: Builder): - if isinstance(value, BuiltinFunctionType): - return builder.create_value(value, value.__name__) - - -@customize -def type_handler(value, builder: Builder): - if isinstance(value, type): - qualname = value.__qualname__ - name = qualname.split(".")[0] - return builder.create_value(value, qualname).with_import(value.__module__, name) - - -@customize -def path_handler(value, builder: Builder): - if isinstance(value, Path): - return builder.create_call(Path, [value.as_posix()]) - - if isinstance(value, PurePath): - return builder.create_call(PurePath, [value.as_posix()]) - - -def sort_set_values(set_values): - is_sorted = False - try: - set_values = sorted(set_values) - is_sorted = True - except TypeError: - pass - - set_values = list(map(repr, set_values)) - if not is_sorted: - set_values = sorted(set_values) - - return set_values - - -@customize -def set_handler(value, builder: Builder): - if isinstance(value, set): - if len(value) == 0: - return builder.create_value(value, "set()") - else: - return builder.create_value( - value, "{" + ", ".join(sort_set_values(value)) + "}" - ) - - -@customize -def frozenset_handler(value, builder: Builder): - if isinstance(value, frozenset): - if len(value) == 0: - return builder.create_value(value, "frozenset()") - else: - return builder.create_call(frozenset, [set(value)]) - - -@customize -def dataclass_handler(value, builder: Builder): - - if is_dataclass(value) and not isinstance(value, type): - - kwargs = {} - - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default != MISSING and field.default == field_value: - is_default = True - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - is_default = True - - if is_default: - field_value = builder.create_default(field_value) - kwargs[field.name] = field_value - - return builder.create_call(type(value), [], kwargs, {}) - - -try: - import attrs -except ImportError: # pragma: no cover - pass -else: - - @customize - def attrs_handler(value, builder: Builder): - - if attrs.has(type(value)): - - kwargs = {} - - for field in attrs.fields(type(value)): - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default is not attrs.NOTHING: - - default_value = ( - field.default - if not isinstance(field.default, attrs.Factory) # type: ignore - else ( - field.default.factory() - if not field.default.takes_self - else field.default.factory(value) - ) - ) - - if default_value == field_value: - is_default = True - - if is_default: - field_value = builder.create_default(field_value) - - kwargs[field.name] = field_value - - return builder.create_call(type(value), [], kwargs, {}) - - -try: - import pydantic -except ImportError: # pragma: no cover - pass -else: - # import pydantic - if pydantic.version.VERSION.startswith("1."): - # pydantic v1 - from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] - - def get_fields(value): - return value.__fields__ - - else: - # pydantic v2 - from pydantic_core import PydanticUndefined - - def get_fields(value): - return type(value).model_fields - - from pydantic import BaseModel - - @customize - def attrs_handler(value, builder: Builder): - - if isinstance(value, BaseModel): - - kwargs = {} - - for name, field in get_fields(value).items(): # type: ignore - if getattr(field, "repr", True): - field_value = getattr(value, name) - is_default = False - - if ( - field.default is not PydanticUndefined - and field.default == field_value - ): - is_default = True - - if ( - field.default_factory is not None - and field.default_factory() == field_value - ): - is_default = True - - if is_default: - field_value = builder.create_default(field_value) - - kwargs[name] = field_value - - return builder.create_call(type(value), [], kwargs, {}) - - -@customize -def namedtuple_handler(value, builder: Builder): - t = type(value) - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return - if not all(type(n) == str for n in f): - return - - # TODO handle with builder.Default - - return builder.create_call( - type(value), - [], - { - field: getattr(value, field) - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - }, - {}, - ) - - -@customize -def defaultdict_handler(value, builder: Builder): - if isinstance(value, defaultdict): - return builder.create_call( - type(value), [value.default_factory, dict(value)], {}, {} - ) - - -@customize -def unmanaged_handler(value, builder: Builder): - if is_unmanaged(value): - return CustomUnmanaged(value=value) - - -@customize -def undefined_handler(value, builder: Builder): - if value is undefined: - return CustomUndefined() - - -@customize(priority=1000) -def dirty_equals_handler(value, builder: Builder): - - if is_dirty_equal(value) and builder._build_new_value: - if isinstance(value, type): - return builder.create_value(value, value.__name__).with_import( - "dirty_equals", value.__name__ - ) - else: - from dirty_equals._utils import Omit - - args = [a for a in value._repr_args if a is not Omit] - kwargs = {k: a for k, a in value._repr_kwargs.items() if a is not Omit} - return builder.create_call(type(value), args, kwargs) - - -@dataclass -class ContextValue: - name: str - value: Any - - -@customize -def context_value_handler(value, builder: Builder): - if isinstance(value, ContextValue): - return builder.create_value(value.value, value.name) - - @dataclass class Builder: _snapshot_context: AdapterContext @@ -720,29 +342,18 @@ def _get_handler(self, v) -> Custom: result = v - custom_functions = [ - f - for _, function_list in sorted(state().custom_functions.items()) - for f in function_list - ] - while not isinstance(result, Custom): - for f in reversed(custom_functions): - with compare_context(): - r = partial_call( - f, - { - "value": result, - "builder": self, - "local_vars": local_vars, - "global_vars": global_vars, - }, - ) - if r is not None: - result = r - break - else: + with compare_context(): + r = state().pm.hook.customize( + value=result, + builder=self, + local_vars=local_vars, + global_vars=global_vars, + ) + if r is None: result = CustomValue(result) + else: + result = r result.__dict__["original_value"] = v return result diff --git a/src/inline_snapshot/_external/_outsource.py b/src/inline_snapshot/_external/_outsource.py index 9f9748b0..dd5e62ca 100644 --- a/src/inline_snapshot/_external/_outsource.py +++ b/src/inline_snapshot/_external/_outsource.py @@ -3,8 +3,6 @@ from dataclasses import dataclass from typing import Any -from inline_snapshot._customize import Builder -from inline_snapshot._customize import customize from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._global_state import state from inline_snapshot._snapshot.generic_value import GenericValue @@ -31,14 +29,6 @@ def __eq__(self, other): return self.data == other -@customize -def outsource_handler(value, builder: Builder): - if isinstance(value, Outsourced): - return builder.create_external( - value.data, format=value.suffix, storage=value.storage - ) - - def outsource(data: Any, suffix: str | None = None, storage: str | None = None) -> Any: if suffix and suffix[0] != ".": raise ValueError("suffix has to start with a '.' like '.png'") diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 2bac6202..f8df70ee 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -from collections import defaultdict from copy import deepcopy from dataclasses import dataclass from dataclasses import field @@ -13,10 +12,11 @@ from typing import Literal from uuid import uuid4 +import pluggy + from inline_snapshot._config import Config if TYPE_CHECKING: - from inline_snapshot._customize._custom import CustomizeHandler from inline_snapshot._external._format._protocol import Format from inline_snapshot._external._storage._protocol import StorageProtocol from inline_snapshot._types import SnapshotRefBase @@ -52,16 +52,16 @@ class State: default_factory=lambda: TemporaryDirectory(prefix="inline-snapshot-") ) + pm: pluggy.PluginManager = field( + default_factory=lambda: pluggy.PluginManager("inline_snapshot") + ) + def new_tmp_path(self, suffix: str) -> Path: assert self.tmp_dir is not None return Path(self.tmp_dir.name) / f"tmp-path-{uuid4()}{suffix}" disable_reason: Literal["xdist", "ci", "implementation", None] = None - custom_functions: dict[int, list[CustomizeHandler]] = field( - default_factory=lambda: defaultdict(list) - ) - _latest_global_states: list[State] = [] @@ -78,12 +78,22 @@ def enter_snapshot_context(): latest = _current _latest_global_states.append(_current) _current = State() - _current.custom_functions = defaultdict( - list, {k: list(v) for k, v in latest.custom_functions.items()} - ) _current.all_formats = dict(latest.all_formats) _current.config = deepcopy(latest.config) + from .plugin._spec import InlineSnapshotPluginSpec + + _current.pm.add_hookspecs(InlineSnapshotPluginSpec) + _current.pm.load_setuptools_entrypoints("inline_snapshot") + + from .plugin._default_plugin import InlineSnapshotAttrsPlugin + from .plugin._default_plugin import InlineSnapshotPlugin + from .plugin._default_plugin import InlineSnapshotPydanticPlugin + + _current.pm.register(InlineSnapshotPlugin()) + _current.pm.register(InlineSnapshotAttrsPlugin()) + _current.pm.register(InlineSnapshotPydanticPlugin()) + def leave_snapshot_context(): global _current diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index 09456764..21f2213d 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -237,6 +237,31 @@ def header(): class SnapshotSession: + def __init__(self): + self.registered_modules = set() + + def register_customize_hooks_from_module(self, module): + """Find and register functions decorated with @customize from a module""" + + if module.__file__ in self.registered_modules: + return + + self.registered_modules.add(module.__file__) + + class ConftestPlugin: + pass + + for name in dir(module): + obj = getattr(module, name, None) + if obj is None or not callable(obj): + continue + + # Check if the function has the customize hookimpl marker + if hasattr(obj, "inline_snapshot_impl"): + setattr(ConftestPlugin, name, obj) + + state().pm.register(ConftestPlugin, name=f"") + @staticmethod def test_enter(): state().missing_values = 0 diff --git a/src/inline_snapshot/plugin/_context_value.py b/src/inline_snapshot/plugin/_context_value.py new file mode 100644 index 00000000..2d2a05b0 --- /dev/null +++ b/src/inline_snapshot/plugin/_context_value.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ContextValue: + name: str + value: Any diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py new file mode 100644 index 00000000..d85d7b3b --- /dev/null +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -0,0 +1,315 @@ +import ast +from collections import Counter +from collections import defaultdict +from dataclasses import MISSING +from dataclasses import fields +from dataclasses import is_dataclass +from pathlib import Path +from pathlib import PurePath +from types import BuiltinFunctionType +from types import FunctionType + +from inline_snapshot._customize import Builder +from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._external._outsource import Outsourced +from inline_snapshot._sentinels import undefined +from inline_snapshot._unmanaged import is_dirty_equal +from inline_snapshot._unmanaged import is_unmanaged +from inline_snapshot._utils import triple_quote +from inline_snapshot.plugin._context_value import ContextValue + +from ._spec import customize + + +class InlineSnapshotPlugin: + @customize + def standard_handler(self, value, builder: Builder): + if isinstance(value, list): + return builder.create_list(value) + + if type(value) is tuple: + return builder.create_tuple(value) + + if isinstance(value, dict): + return builder.create_dict(value) + + @customize + def string_handler(self, value, builder: Builder): + if isinstance(value, str) and ( + ("\n" in value and value[-1] != "\n") or value.count("\n") > 1 + ): + + triple_quoted_string = triple_quote(value) + + assert ast.literal_eval(triple_quoted_string) == value + + return builder.create_value(value, triple_quoted_string) + + @customize(tryfirst=True) + def counter_handler(self, value, builder: Builder): + if isinstance(value, Counter): + return builder.create_call(Counter, [dict(value)]) + + @customize + def function_handler(self, value, builder: Builder): + if isinstance(value, FunctionType): + qualname = value.__qualname__ + name = qualname.split(".")[0] + return builder.create_value(value, qualname).with_import( + value.__module__, name + ) + + @customize + def builtin_function_handler(self, value, builder: Builder): + if isinstance(value, BuiltinFunctionType): + return builder.create_value(value, value.__name__) + + @customize + def type_handler(self, value, builder: Builder): + if isinstance(value, type): + qualname = value.__qualname__ + name = qualname.split(".")[0] + return builder.create_value(value, qualname).with_import( + value.__module__, name + ) + + @customize + def path_handler(self, value, builder: Builder): + if isinstance(value, Path): + return builder.create_call(Path, [value.as_posix()]) + + if isinstance(value, PurePath): + return builder.create_call(PurePath, [value.as_posix()]) + + def sort_set_values(self, set_values): + is_sorted = False + try: + set_values = sorted(set_values) + is_sorted = True + except TypeError: + pass + + set_values = list(map(repr, set_values)) + if not is_sorted: + set_values = sorted(set_values) + + return set_values + + @customize + def set_handler(self, value, builder: Builder): + if isinstance(value, set): + if len(value) == 0: + return builder.create_value(value, "set()") + else: + return builder.create_value( + value, "{" + ", ".join(self.sort_set_values(value)) + "}" + ) + + @customize + def frozenset_handler(self, value, builder: Builder): + if isinstance(value, frozenset): + if len(value) == 0: + return builder.create_value(value, "frozenset()") + else: + return builder.create_call(frozenset, [set(value)]) + + @customize + def dataclass_handler(self, value, builder: Builder): + + if is_dataclass(value) and not isinstance(value, type): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + is_default = False + + if field.default != MISSING and field.default == field_value: + is_default = True + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + is_default = True + + if is_default: + field_value = builder.create_default(field_value) + kwargs[field.name] = field_value + + return builder.create_call(type(value), [], kwargs, {}) + + @customize + def namedtuple_handler(self, value, builder: Builder): + t = type(value) + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return + if not all(type(n) == str for n in f): + return + + # TODO handle with builder.Default + + return builder.create_call( + type(value), + [], + { + field: getattr(value, field) + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + }, + {}, + ) + + @customize(tryfirst=True) + def defaultdict_handler(self, value, builder: Builder): + if isinstance(value, defaultdict): + return builder.create_call( + type(value), [value.default_factory, dict(value)], {}, {} + ) + + @customize + def unmanaged_handler(self, value, builder: Builder): + if is_unmanaged(value): + return CustomUnmanaged(value=value) + + @customize + def undefined_handler(self, value, builder: Builder): + if value is undefined: + return CustomUndefined() + + @customize(tryfirst=True) + def dirty_equals_handler(self, value, builder: Builder): + + if is_dirty_equal(value) and builder._build_new_value: + if isinstance(value, type): + return builder.create_value(value, value.__name__).with_import( + "dirty_equals", value.__name__ + ) + else: + from dirty_equals._utils import Omit + + args = [a for a in value._repr_args if a is not Omit] + kwargs = {k: a for k, a in value._repr_kwargs.items() if a is not Omit} + return builder.create_call(type(value), args, kwargs) + + @customize + def context_value_handler(self, value, builder: Builder): + if isinstance(value, ContextValue): + return builder.create_value(value.value, value.name) + + @customize + def outsource_handler(self, value, builder: Builder): + if isinstance(value, Outsourced): + return builder.create_external( + value.data, format=value.suffix, storage=value.storage + ) + + +try: + import attrs +except ImportError: # pragma: no cover + + class InlineSnapshotAttrsPlugin: + pass + +else: + + class InlineSnapshotAttrsPlugin: + @customize + def attrs_handler(self, value, builder: Builder): + + if attrs.has(type(value)): + + kwargs = {} + + for field in attrs.fields(type(value)): + if field.repr: + field_value = getattr(value, field.name) + is_default = False + + if field.default is not attrs.NOTHING: + + default_value = ( + field.default + if not isinstance(field.default, attrs.Factory) # type: ignore + else ( + field.default.factory() + if not field.default.takes_self + else field.default.factory(value) + ) + ) + + if default_value == field_value: + is_default = True + + if is_default: + field_value = builder.create_default(field_value) + + kwargs[field.name] = field_value + + return builder.create_call(type(value), [], kwargs, {}) + + +try: + import pydantic +except ImportError: # pragma: no cover + + class InlineSnapshotPydanticPlugin: + pass + +else: + # import pydantic + if pydantic.version.VERSION.startswith("1."): + # pydantic v1 + from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] + + def get_fields(value): + return value.__fields__ + + else: + # pydantic v2 + from pydantic_core import PydanticUndefined + + def get_fields(value): + return type(value).model_fields + + from pydantic import BaseModel + + class InlineSnapshotPydanticPlugin: + @customize + def attrs_handler(self, value, builder: Builder): + + if isinstance(value, BaseModel): + + kwargs = {} + + for name, field in get_fields(value).items(): # type: ignore + if getattr(field, "repr", True): + field_value = getattr(value, name) + is_default = False + + if ( + field.default is not PydanticUndefined + and field.default == field_value + ): + is_default = True + + if ( + field.default_factory is not None + and field.default_factory() == field_value + ): + is_default = True + + if is_default: + field_value = builder.create_default(field_value) + + kwargs[name] = field_value + + return builder.create_call(type(value), [], kwargs, {}) diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py new file mode 100644 index 00000000..ea770997 --- /dev/null +++ b/src/inline_snapshot/plugin/_spec.py @@ -0,0 +1,85 @@ +from functools import partial +from typing import Any + +import pluggy + +from inline_snapshot._customize import Builder +from inline_snapshot.plugin._context_value import ContextValue + +hookspec = pluggy.HookspecMarker("inline_snapshot") +hookimpl = pluggy.HookimplMarker("inline_snapshot") + +customize = partial(hookimpl, specname="customize") +""" + Registers a function as a customization hook inside inline-snapshot. + + Customization hooks allow you to control how objects are represented in snapshot code. + When inline-snapshot generates code for a value, it calls each registered customization + function in reverse order of registration until one returns a Custom object. + + **Important**: Customization handlers should be registered in your `conftest.py` file to ensure + they are loaded before your tests run. + + Args: + f: A customization handler function. See [CustomizeHandler][inline_snapshot._customize.CustomizeHandler] + for the expected signature. + + Returns: + The input function unchanged (for use as a decorator) + + Example: + Basic usage with a custom class: + + + ``` python + from inline_snapshot import customize, snapshot + + + class MyClass: + def __init__(self, arg1, arg2, key=None): + self.arg1 = arg1 + self.arg2 = arg2 + self.key_attr = key + + + @customize + def my_custom_handler(value, builder): + if isinstance(value, MyClass): + # Generate code like: MyClass(arg1, arg2, key=value) + return builder.create_call( + MyClass, [value.arg1, value.arg2], {"key": value.key_attr} + ) + return None # Let other handlers process this value + + + def test_myclass(): + obj = MyClass(42, "hello", key="world") + assert obj == snapshot(MyClass(42, "hello", key="world")) + ``` + + Note: + - **Always register handlers in `conftest.py`** to ensure they're available for all tests + - Handlers are called in **reverse order** of registration (last registered is called first) + - If no handler returns a Custom object, a default representation is used + - Use builder methods (`create_call`, `create_list`, `create_dict`, etc.) to construct representations + - Always return `None` if your handler doesn't apply to the given value type + - The builder automatically handles recursive conversion of nested values + + See Also: + - [Builder][inline_snapshot._customize.Builder]: Available builder methods + - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations + """ + + +class InlineSnapshotPluginSpec: + @hookspec(firstresult=True) + def customize( + self, + value: Any, + builder: Builder, + local_vars: list[ContextValue], + global_vars: list[ContextValue], + ) -> Any: ... + + @hookspec + def format_code(self, filename, str) -> str: ... diff --git a/src/inline_snapshot/pytest_plugin.py b/src/inline_snapshot/pytest_plugin.py index f26250cb..86b3ca7e 100644 --- a/src/inline_snapshot/pytest_plugin.py +++ b/src/inline_snapshot/pytest_plugin.py @@ -126,10 +126,33 @@ class InlineSnapshotPlugin: def __init__(self): self.session = SnapshotSession() + @pytest.hookimpl(tryfirst=True) + def pytest_plugin_registered(self, plugin, manager): + """Register @customize hooks from conftest.py files""" + # Skip internal plugins + + if not hasattr(plugin, "__file__") or plugin.__file__ is None: + return + + # Only process conftest.py files + if "conftest.py" not in plugin.__file__: + return + + self.session.register_customize_hooks_from_module(plugin) + @pytest.hookimpl def pytest_configure(self, config): enter_snapshot_context() + # Register customize hooks from all already loaded conftest.py files + for plugin in config.pluginmanager.get_plugins(): + if ( + hasattr(plugin, "__file__") + and plugin.__file__ + and "conftest.py" in plugin.__file__ + ): + self.session.register_customize_hooks_from_module(plugin) + # setup default flags if is_pytest_compatible(): state().config.default_flags_tui = ["create", "review"] diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 135f0648..69f63fb0 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -332,6 +332,29 @@ def report_error(message): report_output = StringIO() console = Console(file=report_output, width=80) + # Load and register all conftest.py files first + for conftest_path in tmp_path.rglob("conftest.py"): + print("load> conftest", conftest_path) + + # Load conftest module using importlib + spec = importlib.util.spec_from_file_location( + f"conftest_{conftest_path.parent.name}", conftest_path + ) + if spec and spec.loader: + conftest_module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = conftest_module + conftest_module.__file__ = str(conftest_path) + spec.loader.exec_module(conftest_module) + + # Register customize hooks from this conftest + session.register_customize_hooks_from_module( + conftest_module + ) + else: + raise UsageError( + f"Could not load conftest from {conftest_path}" + ) + tests_found = False for filename in tmp_path.rglob("test_*.py"): print("run> pytest-inline", filename, *args) diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index b2e01d3c..8bdfbdd0 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -487,11 +487,8 @@ def test_something(): def test_remove_positional_argument(): Example( - """\ -from inline_snapshot import snapshot -from inline_snapshot._customize import CustomCall,customize - - + { + "tests/helper.py": """\ class L: def __init__(self,*l): self.l=l @@ -500,11 +497,20 @@ def __eq__(self,other): if not isinstance(other,L): return NotImplemented return other.l==self.l +""", + "tests/conftest.py": """\ +from inline_snapshot import customize +from helper import L @customize -def handle(value,builder): +def handle_L(value,builder): if isinstance(value,L): return builder.create_call(L,value.l) +""", + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from helper import L + def test_L1(): for _ in [1,2]: @@ -517,29 +523,16 @@ def test_L2(): def test_L3(): for _ in [1,2]: assert L(1,2) == snapshot(L(1, 2)), "not equal" -""" +""", + } ).run_pytest(returncode=snapshot(1)).run_pytest( ["--inline-snapshot=fix"], changed_files=snapshot( { "tests/test_something.py": """\ from inline_snapshot import snapshot -from inline_snapshot._customize import CustomCall,customize - - -class L: - def __init__(self,*l): - self.l=l - - def __eq__(self,other): - if not isinstance(other,L): - return NotImplemented - return other.l==self.l +from helper import L -@customize -def handle(value,builder): - if isinstance(value,L): - return builder.create_call(L,value.l) def test_L1(): for _ in [1,2]: diff --git a/tests/test_customize.py b/tests/test_customize.py index 7da7f97b..616becff 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -1,8 +1,6 @@ import pytest from inline_snapshot import snapshot -from inline_snapshot._customize import customize -from inline_snapshot.extra import raises from inline_snapshot.testing import Example @@ -12,34 +10,32 @@ def test_custom_dirty_equal(original, flag): Example( - f"""\ + { + "tests/conftest.py": """\ from inline_snapshot import customize from inline_snapshot import Builder -from inline_snapshot import snapshot from dirty_equals import IsStr @customize def re_handler(value, builder: Builder): if value == IsStr(regex="[a-z]"): - return builder.create_call(IsStr, [], {{"regex": "[a-z]"}}) + return builder.create_call(IsStr, [], {"regex": "[a-z]"}) +""", + "tests/test_something.py": f"""\ +from inline_snapshot import snapshot def test_a(): assert snapshot({original}) == "a" -""" +""", + } ).run_inline( [f"--inline-snapshot={flag}"], changed_files=snapshot( { "tests/test_something.py": """\ -from inline_snapshot import customize -from inline_snapshot import Builder from inline_snapshot import snapshot -from dirty_equals import IsStr -@customize -def re_handler(value, builder: Builder): - if value == IsStr(regex="[a-z]"): - return builder.create_call(IsStr, [], {"regex": "[a-z]"}) +from dirty_equals import IsStr def test_a(): assert snapshot(IsStr(regex="[a-z]")) == "a" @@ -87,44 +83,3 @@ def test(): } ), ) - - -def test_customize_argument_exceptions(): - - with raises( - snapshot("UsageError: `value` has a default value which is not supported") - ): - - @customize - def f(value=5): - pass - - with raises( - snapshot( - "UsageError: `value` is not a positional or keyword parameter, which is not supported" - ) - ): - - @customize - def f(value, /): - pass - - with raises( - snapshot( - "UsageError: `value` is not a positional or keyword parameter, which is not supported" - ) - ): - - @customize - def f(*, value): - pass - - with raises( - snapshot( - "UsageError: `my_own_arg` is an unknown parameter. allowed are ['builder', 'global_vars', 'local_vars', 'value']" - ) - ): - - @customize - def f(my_own_arg): - pass From 81f7d44b9c70f7a0674c8ae40ef11a241912d2f3 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 4 Jan 2026 21:59:33 +0100 Subject: [PATCH 27/72] refactor: moved customization classes --- src/inline_snapshot/__init__.py | 6 +- src/inline_snapshot/_code_repr.py | 2 +- src/inline_snapshot/_customize/__init__.py | 432 ------------------ src/inline_snapshot/_customize/_builder.py | 136 ++++++ src/inline_snapshot/_customize/_custom.py | 2 +- .../_customize/_custom_call.py | 90 ++++ .../_customize/_custom_dict.py | 34 ++ .../_customize/_custom_external.py | 58 +++ .../_customize/_custom_sequence.py | 52 +++ .../_customize/_custom_undefined.py | 21 + .../_customize/_custom_unmanaged.py | 22 + .../_customize/_custom_value.py | 67 +++ src/inline_snapshot/_inline_snapshot.py | 2 +- src/inline_snapshot/_new_adapter.py | 18 +- .../_snapshot/collection_value.py | 4 +- src/inline_snapshot/_snapshot/dict_value.py | 4 +- src/inline_snapshot/_snapshot/eq_value.py | 2 +- .../_snapshot/generic_value.py | 6 +- .../_snapshot/min_max_value.py | 2 +- .../_snapshot/undecided_value.py | 16 +- src/inline_snapshot/plugin/_default_plugin.py | 6 +- src/inline_snapshot/plugin/_spec.py | 2 +- 22 files changed, 516 insertions(+), 468 deletions(-) create mode 100644 src/inline_snapshot/_customize/_builder.py create mode 100644 src/inline_snapshot/_customize/_custom_call.py create mode 100644 src/inline_snapshot/_customize/_custom_dict.py create mode 100644 src/inline_snapshot/_customize/_custom_external.py create mode 100644 src/inline_snapshot/_customize/_custom_sequence.py create mode 100644 src/inline_snapshot/_customize/_custom_undefined.py create mode 100644 src/inline_snapshot/_customize/_custom_unmanaged.py create mode 100644 src/inline_snapshot/_customize/_custom_value.py diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 2c63b0e3..86edcd24 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -3,9 +3,9 @@ from ._code_repr import HasRepr from ._code_repr import customize_repr -from ._customize import Builder -from ._customize import Custom -from ._customize import CustomizeHandler +from ._customize._builder import Builder +from ._customize._custom import Custom +from ._customize._custom import CustomizeHandler from ._exceptions import UsageError from ._external._external import external from ._external._external_file import external_file diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 4ed0fb83..a39313f6 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -79,7 +79,7 @@ def code_repr(obj): @contextmanager def mock_repr(context: AdapterContext): def new_repr(obj): - from inline_snapshot._customize import Builder + from inline_snapshot._customize._builder import Builder return only_value( Builder(_snapshot_context=context)._get_handler(obj).repr(context) diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py index 150eafbf..e69de29b 100644 --- a/src/inline_snapshot/_customize/__init__.py +++ b/src/inline_snapshot/_customize/__init__.py @@ -1,432 +0,0 @@ -from __future__ import annotations - -import ast -import importlib -from abc import ABC -from abc import abstractmethod -from collections import Counter -from collections import defaultdict -from dataclasses import MISSING -from dataclasses import dataclass -from dataclasses import field -from dataclasses import fields -from dataclasses import is_dataclass -from dataclasses import replace -from functools import partial -from pathlib import Path -from pathlib import PurePath -from types import BuiltinFunctionType -from types import FunctionType -from typing import Any -from typing import Callable -from typing import Generator -from typing import Optional -from typing import TypeAlias -from typing import overload - -from inline_snapshot._adapter_context import AdapterContext -from inline_snapshot._change import Change -from inline_snapshot._change import ChangeBase -from inline_snapshot._change import ExternalChange -from inline_snapshot._code_repr import HasRepr -from inline_snapshot._code_repr import value_code_repr -from inline_snapshot._compare_context import compare_context -from inline_snapshot._compare_context import compare_only -from inline_snapshot._customize._custom import CustomizeHandler -from inline_snapshot._external._external_location import ExternalLocation -from inline_snapshot._external._format._protocol import get_format_handler -from inline_snapshot._external._format._protocol import get_format_handler_from_suffix -from inline_snapshot._global_state import state -from inline_snapshot._partial_call import partial_call -from inline_snapshot._partial_call import partial_check_args -from inline_snapshot._sentinels import undefined -from inline_snapshot._unmanaged import is_dirty_equal -from inline_snapshot._unmanaged import is_unmanaged -from inline_snapshot._utils import clone -from inline_snapshot._utils import triple_quote -from inline_snapshot.plugin._context_value import ContextValue - -from ._custom import Custom -from ._custom import CustomizeHandler - - -@dataclass(frozen=True) -class CustomDefault(Custom): - value: Custom = field(compare=False) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () # pragma: no cover - # this should never be called because default values are never converted into code - assert False - - def map(self, f): - return self.value.map(f) - - def _needed_imports(self): - yield from self.value._needed_imports() - - -@dataclass() -class CustomUnmanaged(Custom): - value: Any - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () # pragma: no cover - return "'unmanaged'" - - def map(self, f): - return f(self.value) - - -class CustomUndefined(Custom): - def __init__(self): - self.value = undefined - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () - return "..." - - def map(self, f): - return f(undefined) - - -def unwrap_default(value): - if isinstance(value, CustomDefault): - return value.value - return value - - -@dataclass(frozen=True) -class CustomCall(Custom): - node_type = ast.Call - _function: Custom = field(compare=False) - _args: list[Custom] = field(compare=False) - _kwargs: dict[str, Custom] = field(compare=False) - _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - args = [] - for a in self.args: - v = yield from a.repr(context) - args.append(v) - - for k, v in self.kwargs.items(): - if not isinstance(v, CustomDefault): - value = yield from v.repr(context) - args.append(f"{k}={value}") - - return f"{yield from self._function.repr(context)}({', '.join(args)})" - - @property - def args(self): - return self._args - - @property - def all_pos_args(self): - return [*self._args, *self._kwargs.values()] - - @property - def kwargs(self): - return {**self._kwargs, **self._kwonly} - - def argument(self, pos_or_str): - if isinstance(pos_or_str, int): - return unwrap_default(self.all_pos_args[pos_or_str]) - else: - return unwrap_default(self.kwargs[pos_or_str]) - - def map(self, f): - return self._function.map(f)( - *[f(x.map(f)) for x in self._args], - **{k: f(v.map(f)) for k, v in self.kwargs.items()}, - ) - - def _needed_imports(self): - yield from self._function._needed_imports() - for v in self._args: - yield from v._needed_imports() - - for v in self._kwargs.values(): - yield from v._needed_imports() - - for v in self._kwonly.values(): - yield from v._needed_imports() - - -class CustomSequenceTypes: - trailing_comma: bool - braces: str - value_type: type - - -@dataclass(frozen=True) -class CustomSequence(Custom, CustomSequenceTypes): - value: list[Custom] = field(compare=False) - - def map(self, f): - return f(self.value_type([x.map(f) for x in self.value])) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - values = [] - for v in self.value: - value = yield from v.repr(context) - values.append(value) - - trailing_comma = self.trailing_comma and len(self.value) == 1 - return f"{self.braces[0]}{', '.join(values)}{', ' if trailing_comma else ''}{self.braces[1]}" - - def _needed_imports(self): - for v in self.value: - yield from v._needed_imports() - - -class CustomList(CustomSequence): - node_type = ast.List - value_type = list - braces = "[]" - trailing_comma = False - - -class CustomTuple(CustomSequence): - node_type = ast.Tuple - value_type = tuple - braces = "()" - trailing_comma = True - - -@dataclass(frozen=True) -class CustomDict(Custom): - node_type = ast.Dict - value: dict[Custom, Custom] = field(compare=False) - - def map(self, f): - return f({k.map(f): v.map(f) for k, v in self.value.items()}) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - values = [] - for k, v in self.value.items(): - key = yield from k.repr(context) - value = yield from v.repr(context) - values.append(f"{key}: {value}") - - return f"{{{ ', '.join(values)}}}" - - def _needed_imports(self): - for k, v in self.value.items(): - yield from k._needed_imports() - yield from v._needed_imports() - - -@dataclass(frozen=True) -class CustomExternal(Custom): - value: Any - format: str | None = None - storage: str | None = None - - def map(self, f): - return f(self.value) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - storage_name = self.storage or state().config.default_storage - - format = get_format_handler(self.value, self.format or "") - - location = ExternalLocation( - storage=storage_name, - stem="", - suffix=self.format or format.suffix, - filename=Path(context.file.filename), - qualname=context.qualname, - ) - - tmp_file = state().new_tmp_path(location.suffix) - - storage = state().all_storages[storage_name] - - format.encode(self.value, tmp_file) - location = storage.new_location(location, tmp_file) - - yield ExternalChange( - "create", - tmp_file, - ExternalLocation.from_name("", context=context), - location, - format, - ) - - return f"external({location.to_str()!r})" - - def _needed_imports(self): - return [("inline_snapshot", ["external"])] - - -class CustomValue(Custom): - def __init__(self, value, repr_str=None): - assert not isinstance(value, Custom) - value = clone(value) - self._imports = defaultdict(list) - - if repr_str is None: - self.repr_str = value_code_repr(value) - - try: - ast.parse(self.repr_str) - except SyntaxError: - self.repr_str = HasRepr(type(value), self.repr_str).__repr__() - self.with_import("inline_snapshot", "HasRepr") - else: - self.repr_str = repr_str - - self.value = value - - super().__init__() - - def map(self, f): - return f(self.value) - - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () - return self.repr_str - - def __repr__(self): - return f"CustomValue({self.repr_str})" - - def _needed_imports(self): - yield from self._imports.items() - - def with_import(self, module, name, simplify=True): - value = getattr(importlib.import_module(module), name) - if simplify: - parts = module.split(".") - while len(parts) >= 2: - if ( - getattr(importlib.import_module(".".join(parts[:-1])), name, None) - == value - ): - parts.pop() - else: - break - module = ".".join(parts) - - self._imports[module].append(name) - - return self - - -@dataclass -class Builder: - _snapshot_context: AdapterContext - _build_new_value: bool = False - - def _get_handler(self, v) -> Custom: - - from inline_snapshot._global_state import state - - if ( - self._snapshot_context is not None - and (frame := self._snapshot_context.frame) is not None - ): - local_vars = [ - ContextValue(var_name, var_value) - for var_name, var_value in frame.locals.items() - if "@" not in var_name - ] - global_vars = [ - ContextValue(var_name, var_value) - for var_name, var_value in frame.globals.items() - if "@" not in var_name - ] - else: - local_vars = [] - global_vars = [] - - result = v - - while not isinstance(result, Custom): - with compare_context(): - r = state().pm.hook.customize( - value=result, - builder=self, - local_vars=local_vars, - global_vars=global_vars, - ) - if r is None: - result = CustomValue(result) - else: - result = r - - result.__dict__["original_value"] = v - return result - - def create_external(self, value: Any, format: str | None, storage: str | None): - - return CustomExternal(value, format=format, storage=storage) - - def create_list(self, value) -> Custom: - """ - Creates an intermediate node for a list-expression which can be used as a result for your customization function. - - `create_list([1,2,3])` becomes `[1,2,3]` in the code. - List elements are recursively converted into CustomNodes. - """ - custom = [self._get_handler(v) for v in value] - return CustomList(value=custom) - - def create_tuple(self, value) -> Custom: - """ - Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. - - `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. - Tuple elements are recursively converted into CustomNodes. - """ - custom = [self._get_handler(v) for v in value] - return CustomTuple(value=custom) - - def create_call( - self, function, posonly_args=[], kwargs={}, kwonly_args={} - ) -> Custom: - """ - Creates an intermediate node for a function call expression which can be used as a result for your customization function. - - `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. - Function, arguments, and keyword arguments are recursively converted into CustomNodes. - """ - function = self._get_handler(function) - posonly_args = [self._get_handler(arg) for arg in posonly_args] - kwargs = {k: self._get_handler(arg) for k, arg in kwargs.items()} - kwonly_args = {k: self._get_handler(arg) for k, arg in kwonly_args.items()} - - return CustomCall( - _function=function, - _args=posonly_args, - _kwargs=kwargs, - _kwonly=kwonly_args, - ) - - def create_default(self, value) -> Custom: - """ - Creates an intermediate node for a default value which can be used as a result for your customization function. - - Default values are not included in the generated code when they match the actual default. - The value is recursively converted into a CustomNode. - """ - return CustomDefault(value=self._get_handler(value)) - - def create_dict(self, value) -> Custom: - """ - Creates an intermediate node for a dict-expression which can be used as a result for your customization function. - - `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. - Dict keys and values are recursively converted into CustomNodes. - """ - custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} - return CustomDict(value=custom) - - def create_value(self, value, repr: str | None = None) -> CustomValue: - """ - Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. - - `create_value(my_obj, 'MyClass')` becomes `MyClass` in the code. - Use this when you want to control the exact string representation of a value. - """ - return CustomValue(value, repr) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py new file mode 100644 index 00000000..33e8ed21 --- /dev/null +++ b/src/inline_snapshot/_customize/_builder.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._compare_context import compare_context +from inline_snapshot.plugin._context_value import ContextValue + +from ._custom import Custom +from ._custom_call import CustomCall +from ._custom_call import CustomDefault +from ._custom_dict import CustomDict +from ._custom_external import CustomExternal +from ._custom_sequence import CustomList +from ._custom_sequence import CustomTuple +from ._custom_value import CustomValue + + +@dataclass +class Builder: + _snapshot_context: AdapterContext + _build_new_value: bool = False + + def _get_handler(self, v) -> Custom: + + from inline_snapshot._global_state import state + + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + local_vars = [ + ContextValue(var_name, var_value) + for var_name, var_value in frame.locals.items() + if "@" not in var_name + ] + global_vars = [ + ContextValue(var_name, var_value) + for var_name, var_value in frame.globals.items() + if "@" not in var_name + ] + else: + local_vars = [] + global_vars = [] + + result = v + + while not isinstance(result, Custom): + with compare_context(): + r = state().pm.hook.customize( + value=result, + builder=self, + local_vars=local_vars, + global_vars=global_vars, + ) + if r is None: + result = CustomValue(result) + else: + result = r + + result.__dict__["original_value"] = v + return result + + def create_external(self, value: Any, format: str | None, storage: str | None): + + return CustomExternal(value, format=format, storage=storage) + + def create_list(self, value) -> Custom: + """ + Creates an intermediate node for a list-expression which can be used as a result for your customization function. + + `create_list([1,2,3])` becomes `[1,2,3]` in the code. + List elements are recursively converted into CustomNodes. + """ + custom = [self._get_handler(v) for v in value] + return CustomList(value=custom) + + def create_tuple(self, value) -> Custom: + """ + Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. + + `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. + Tuple elements are recursively converted into CustomNodes. + """ + custom = [self._get_handler(v) for v in value] + return CustomTuple(value=custom) + + def create_call( + self, function, posonly_args=[], kwargs={}, kwonly_args={} + ) -> Custom: + """ + Creates an intermediate node for a function call expression which can be used as a result for your customization function. + + `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. + Function, arguments, and keyword arguments are recursively converted into CustomNodes. + """ + function = self._get_handler(function) + posonly_args = [self._get_handler(arg) for arg in posonly_args] + kwargs = {k: self._get_handler(arg) for k, arg in kwargs.items()} + kwonly_args = {k: self._get_handler(arg) for k, arg in kwonly_args.items()} + + return CustomCall( + _function=function, + _args=posonly_args, + _kwargs=kwargs, + _kwonly=kwonly_args, + ) + + def create_default(self, value) -> Custom: + """ + Creates an intermediate node for a default value which can be used as a result for your customization function. + + Default values are not included in the generated code when they match the actual default. + The value is recursively converted into a CustomNode. + """ + return CustomDefault(value=self._get_handler(value)) + + def create_dict(self, value) -> Custom: + """ + Creates an intermediate node for a dict-expression which can be used as a result for your customization function. + + `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. + Dict keys and values are recursively converted into CustomNodes. + """ + custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} + return CustomDict(value=custom) + + def create_value(self, value, repr: str | None = None) -> CustomValue: + """ + Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. + + `create_value(my_obj, 'MyClass')` becomes `MyClass` in the code. + Use this when you want to control the exact string representation of a value. + """ + return CustomValue(value, repr) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 80dbe65a..e1847522 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -13,7 +13,7 @@ from inline_snapshot._change import ChangeBase if TYPE_CHECKING: - from inline_snapshot._customize import Builder + from inline_snapshot._customize._builder import Builder class Custom(ABC): diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py new file mode 100644 index 00000000..ddd69372 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomDefault(Custom): + value: Custom = field(compare=False) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () # pragma: no cover + # this should never be called because default values are never converted into code + assert False + + def map(self, f): + return self.value.map(f) + + def _needed_imports(self): + yield from self.value._needed_imports() + + +def unwrap_default(value): + if isinstance(value, CustomDefault): + return value.value + return value + + +@dataclass(frozen=True) +class CustomCall(Custom): + node_type = ast.Call + _function: Custom = field(compare=False) + _args: list[Custom] = field(compare=False) + _kwargs: dict[str, Custom] = field(compare=False) + _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + args = [] + for a in self.args: + v = yield from a.repr(context) + args.append(v) + + for k, v in self.kwargs.items(): + if not isinstance(v, CustomDefault): + value = yield from v.repr(context) + args.append(f"{k}={value}") + + return f"{yield from self._function.repr(context)}({', '.join(args)})" + + @property + def args(self): + return self._args + + @property + def all_pos_args(self): + return [*self._args, *self._kwargs.values()] + + @property + def kwargs(self): + return {**self._kwargs, **self._kwonly} + + def argument(self, pos_or_str): + if isinstance(pos_or_str, int): + return unwrap_default(self.all_pos_args[pos_or_str]) + else: + return unwrap_default(self.kwargs[pos_or_str]) + + def map(self, f): + return self._function.map(f)( + *[f(x.map(f)) for x in self._args], + **{k: f(v.map(f)) for k, v in self.kwargs.items()}, + ) + + def _needed_imports(self): + yield from self._function._needed_imports() + for v in self._args: + yield from v._needed_imports() + + for v in self._kwargs.values(): + yield from v._needed_imports() + + for v in self._kwonly.values(): + yield from v._needed_imports() diff --git a/src/inline_snapshot/_customize/_custom_dict.py b/src/inline_snapshot/_customize/_custom_dict.py new file mode 100644 index 00000000..fafd49bd --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_dict.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomDict(Custom): + node_type = ast.Dict + value: dict[Custom, Custom] = field(compare=False) + + def map(self, f): + return f({k.map(f): v.map(f) for k, v in self.value.items()}) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for k, v in self.value.items(): + key = yield from k.repr(context) + value = yield from v.repr(context) + values.append(f"{key}: {value}") + + return f"{{{ ', '.join(values)}}}" + + def _needed_imports(self): + for k, v in self.value.items(): + yield from k._needed_imports() + yield from v._needed_imports() diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py new file mode 100644 index 00000000..0bd05c45 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._change import ExternalChange +from inline_snapshot._external._external_location import ExternalLocation +from inline_snapshot._external._format._protocol import get_format_handler +from inline_snapshot._global_state import state + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomExternal(Custom): + value: Any + format: str | None = None + storage: str | None = None + + def map(self, f): + return f(self.value) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + storage_name = self.storage or state().config.default_storage + + format = get_format_handler(self.value, self.format or "") + + location = ExternalLocation( + storage=storage_name, + stem="", + suffix=self.format or format.suffix, + filename=Path(context.file.filename), + qualname=context.qualname, + ) + + tmp_file = state().new_tmp_path(location.suffix) + + storage = state().all_storages[storage_name] + + format.encode(self.value, tmp_file) + location = storage.new_location(location, tmp_file) + + yield ExternalChange( + "create", + tmp_file, + ExternalLocation.from_name("", context=context), + location, + format, + ) + + return f"external({location.to_str()!r})" + + def _needed_imports(self): + return [("inline_snapshot", ["external"])] diff --git a/src/inline_snapshot/_customize/_custom_sequence.py b/src/inline_snapshot/_customize/_custom_sequence.py new file mode 100644 index 00000000..ad14878e --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_sequence.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +class CustomSequenceTypes: + trailing_comma: bool + braces: str + value_type: type + + +@dataclass(frozen=True) +class CustomSequence(Custom, CustomSequenceTypes): + value: list[Custom] = field(compare=False) + + def map(self, f): + return f(self.value_type([x.map(f) for x in self.value])) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for v in self.value: + value = yield from v.repr(context) + values.append(value) + + trailing_comma = self.trailing_comma and len(self.value) == 1 + return f"{self.braces[0]}{', '.join(values)}{', ' if trailing_comma else ''}{self.braces[1]}" + + def _needed_imports(self): + for v in self.value: + yield from v._needed_imports() + + +class CustomList(CustomSequence): + node_type = ast.List + value_type = list + braces = "[]" + trailing_comma = False + + +class CustomTuple(CustomSequence): + node_type = ast.Tuple + value_type = tuple + braces = "()" + trailing_comma = True diff --git a/src/inline_snapshot/_customize/_custom_undefined.py b/src/inline_snapshot/_customize/_custom_undefined.py new file mode 100644 index 00000000..f7b06344 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_undefined.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._sentinels import undefined + +from ._custom import Custom + + +class CustomUndefined(Custom): + def __init__(self): + self.value = undefined + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () + return "..." + + def map(self, f): + return f(undefined) diff --git a/src/inline_snapshot/_customize/_custom_unmanaged.py b/src/inline_snapshot/_customize/_custom_unmanaged.py new file mode 100644 index 00000000..e909ea6d --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_unmanaged.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass() +class CustomUnmanaged(Custom): + value: Any + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () # pragma: no cover + return "'unmanaged'" + + def map(self, f): + return f(self.value) diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py new file mode 100644 index 00000000..2842c419 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import ast +import importlib +from collections import defaultdict +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._code_repr import HasRepr +from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._utils import clone + +from ._custom import Custom + + +class CustomValue(Custom): + def __init__(self, value, repr_str=None): + assert not isinstance(value, Custom) + value = clone(value) + self._imports = defaultdict(list) + + if repr_str is None: + self.repr_str = value_code_repr(value) + + try: + ast.parse(self.repr_str) + except SyntaxError: + self.repr_str = HasRepr(type(value), self.repr_str).__repr__() + self.with_import("inline_snapshot", "HasRepr") + else: + self.repr_str = repr_str + + self.value = value + + super().__init__() + + def map(self, f): + return f(self.value) + + def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () + return self.repr_str + + def __repr__(self): + return f"CustomValue({self.repr_str})" + + def _needed_imports(self): + yield from self._imports.items() + + def with_import(self, module, name, simplify=True): + value = getattr(importlib.import_module(module), name) + if simplify: + parts = module.split(".") + while len(parts) >= 2: + if ( + getattr(importlib.import_module(".".join(parts[:-1])), name, None) + == value + ): + parts.pop() + else: + break + module = ".".join(parts) + + self._imports[module].append(name) + + return self diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index dc829d91..37d8804d 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -10,7 +10,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._adapter_context import FrameContext -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._source_file import SourceFile from inline_snapshot._types import SnapshotRefBase diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 12b7eda0..8ebd978a 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -17,15 +17,15 @@ from inline_snapshot._change import Replace from inline_snapshot._change import RequiredImports from inline_snapshot._compare_context import compare_context -from inline_snapshot._customize import Custom -from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import CustomDefault -from inline_snapshot._customize import CustomDict -from inline_snapshot._customize import CustomList -from inline_snapshot._customize import CustomSequence -from inline_snapshot._customize import CustomUndefined -from inline_snapshot._customize import CustomUnmanaged -from inline_snapshot._customize import CustomValue +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_call import CustomDefault +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_sequence import CustomSequence +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged +from inline_snapshot._customize._custom_value import CustomValue from inline_snapshot._exceptions import UsageError from inline_snapshot._generator_utils import only_value from inline_snapshot.syntax_warnings import InlineSnapshotInfo diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 704c9dd1..b36d53a8 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -2,8 +2,8 @@ from typing import Generator from typing import Iterator -from inline_snapshot._customize import CustomList -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_undefined import CustomUndefined from .._change import ChangeBase from .._change import Delete diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index 5424b320..aa98c941 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -2,8 +2,8 @@ from typing import Generator from typing import Iterator -from inline_snapshot._customize import CustomDict -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_undefined import CustomUndefined from .._adapter_context import AdapterContext from .._change import ChangeBase diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 0a1841ee..d1502928 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -2,7 +2,7 @@ from typing import Iterator from typing import List -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._generator_utils import split_gen from inline_snapshot._new_adapter import NewAdapter diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 37546b86..8d803b21 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -4,9 +4,9 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._code_repr import mock_repr -from inline_snapshot._customize import Builder -from inline_snapshot._customize import Custom -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._new_adapter import reeval from .._change import ChangeBase diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index c22ec34f..1d70dba6 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,7 +1,7 @@ from typing import Generator from typing import Iterator -from inline_snapshot._customize import CustomUndefined +from inline_snapshot._customize._custom_undefined import CustomUndefined from .._change import ChangeBase from .._change import Replace diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index f665a706..f0e3eddc 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -3,14 +3,14 @@ from typing import Iterator from inline_snapshot._compare_context import compare_only -from inline_snapshot._customize import Custom -from inline_snapshot._customize import CustomCall -from inline_snapshot._customize import CustomDict -from inline_snapshot._customize import CustomList -from inline_snapshot._customize import CustomTuple -from inline_snapshot._customize import CustomUndefined -from inline_snapshot._customize import CustomUnmanaged -from inline_snapshot._customize import CustomValue +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_sequence import CustomTuple +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged +from inline_snapshot._customize._custom_value import CustomValue from inline_snapshot._new_adapter import NewAdapter from inline_snapshot._new_adapter import warn_star_expression from inline_snapshot._unmanaged import is_unmanaged diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index d85d7b3b..f9501cea 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -9,9 +9,9 @@ from types import BuiltinFunctionType from types import FunctionType -from inline_snapshot._customize import Builder -from inline_snapshot._customize import CustomUndefined -from inline_snapshot._customize import CustomUnmanaged +from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged from inline_snapshot._external._outsource import Outsourced from inline_snapshot._sentinels import undefined from inline_snapshot._unmanaged import is_dirty_equal diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index ea770997..ddf63a9f 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -3,7 +3,7 @@ import pluggy -from inline_snapshot._customize import Builder +from inline_snapshot._customize._builder import Builder from inline_snapshot.plugin._context_value import ContextValue hookspec = pluggy.HookspecMarker("inline_snapshot") From 512652e43bf9a808561558b063720bc52e726f73 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 4 Jan 2026 22:18:06 +0100 Subject: [PATCH 28/72] refactor: changed _needed_imports signature --- .../_customize/_custom_external.py | 2 +- .../_customize/_custom_value.py | 32 +++++++++---------- src/inline_snapshot/_inline_snapshot.py | 4 +-- src/inline_snapshot/_new_adapter.py | 4 +-- src/inline_snapshot/_partial_call.py | 26 --------------- src/inline_snapshot/plugin/_spec.py | 9 +++--- 6 files changed, 26 insertions(+), 51 deletions(-) delete mode 100644 src/inline_snapshot/_partial_call.py diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index 0bd05c45..0ef47092 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -55,4 +55,4 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: return f"external({location.to_str()!r})" def _needed_imports(self): - return [("inline_snapshot", ["external"])] + return [("inline_snapshot", "external")] diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index 2842c419..9118f7e0 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -2,7 +2,6 @@ import ast import importlib -from collections import defaultdict from typing import Generator from inline_snapshot._adapter_context import AdapterContext @@ -14,11 +13,23 @@ from ._custom import Custom +def _simplify_module_path(module: str, name: str) -> str: + """Simplify module path by finding the shortest import path for a given name.""" + value = getattr(importlib.import_module(module), name) + parts = module.split(".") + while len(parts) >= 2: + if getattr(importlib.import_module(".".join(parts[:-1])), name, None) == value: + parts.pop() + else: + break + return ".".join(parts) + + class CustomValue(Custom): def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) value = clone(value) - self._imports = defaultdict(list) + self._imports = [] if repr_str is None: self.repr_str = value_code_repr(value) @@ -46,22 +57,11 @@ def __repr__(self): return f"CustomValue({self.repr_str})" def _needed_imports(self): - yield from self._imports.items() + yield from self._imports def with_import(self, module, name, simplify=True): - value = getattr(importlib.import_module(module), name) if simplify: - parts = module.split(".") - while len(parts) >= 2: - if ( - getattr(importlib.import_module(".".join(parts[:-1])), name, None) - == value - ): - parts.pop() - else: - break - module = ".".join(parts) - - self._imports[module].append(name) + module = _simplify_module_path(module, name) + self._imports.append([module, name]) return self diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 37d8804d..3e9a61a2 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -152,8 +152,8 @@ def _changes(self) -> Iterator[ChangeBase]: ) imports: dict[str, set[str]] = defaultdict(set) - for module, names in self._value._needed_imports(): - imports[module] |= set(names) + for module, name in self._value._needed_imports(): + imports[module].add(name) yield RequiredImports( flag="create", file=self._value._file, imports=imports diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 8ebd978a..b2754f55 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -224,8 +224,8 @@ def compare_CustomValue( def needed_imports(value: Custom): imports: dict[str, set] = defaultdict(set) - for module, names in value._needed_imports(): - imports[module] |= set(names) + for module, name in value._needed_imports(): + imports[module].add(name) return imports if imports := needed_imports(new_value): diff --git a/src/inline_snapshot/_partial_call.py b/src/inline_snapshot/_partial_call.py deleted file mode 100644 index 87a1020a..00000000 --- a/src/inline_snapshot/_partial_call.py +++ /dev/null @@ -1,26 +0,0 @@ -import inspect - -from inline_snapshot._exceptions import UsageError - - -def partial_check_args(func, allowed): - sign = inspect.signature(func) - for p in sign.parameters.values(): - if p.default is not inspect.Parameter.empty: - raise UsageError(f"`{p.name}` has a default value which is not supported") - - if p.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: - raise UsageError( - f"`{p.name}` is not a positional or keyword parameter, which is not supported" - ) - - if p.name not in allowed: - raise UsageError( - f"`{p.name}` is an unknown parameter. allowed are {sorted(allowed)}" - ) - - -def partial_call(func, args): - sign = inspect.signature(func) - used = [p.name for p in sign.parameters.values()] - return func(**{n: args[n] for n in used}) diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index ddf63a9f..28361542 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -1,5 +1,6 @@ from functools import partial from typing import Any +from typing import List import pluggy @@ -77,9 +78,9 @@ def customize( self, value: Any, builder: Builder, - local_vars: list[ContextValue], - global_vars: list[ContextValue], + local_vars: List[ContextValue], + global_vars: List[ContextValue], ) -> Any: ... - @hookspec - def format_code(self, filename, str) -> str: ... + # @hookspec + # def format_code(self, filename, str) -> str: ... From 931ed3433102e745918f0125f7faf24599e60cda Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 4 Jan 2026 22:40:46 +0100 Subject: [PATCH 29/72] fix: fixed type errors --- src/inline_snapshot/_global_state.py | 18 ++++++++++++++---- src/inline_snapshot/plugin/_default_plugin.py | 7 ++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index f8df70ee..3338feb2 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -86,13 +86,23 @@ def enter_snapshot_context(): _current.pm.add_hookspecs(InlineSnapshotPluginSpec) _current.pm.load_setuptools_entrypoints("inline_snapshot") - from .plugin._default_plugin import InlineSnapshotAttrsPlugin from .plugin._default_plugin import InlineSnapshotPlugin - from .plugin._default_plugin import InlineSnapshotPydanticPlugin _current.pm.register(InlineSnapshotPlugin()) - _current.pm.register(InlineSnapshotAttrsPlugin()) - _current.pm.register(InlineSnapshotPydanticPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotAttrsPlugin + except ImportError: + pass + else: + _current.pm.register(InlineSnapshotAttrsPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotPydanticPlugin + except ImportError: + pass + else: + _current.pm.register(InlineSnapshotPydanticPlugin()) def leave_snapshot_context(): diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index f9501cea..23fb77b2 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -216,8 +216,7 @@ def outsource_handler(self, value, builder: Builder): import attrs except ImportError: # pragma: no cover - class InlineSnapshotAttrsPlugin: - pass + pass else: @@ -260,9 +259,7 @@ def attrs_handler(self, value, builder: Builder): try: import pydantic except ImportError: # pragma: no cover - - class InlineSnapshotPydanticPlugin: - pass + pass else: # import pydantic From 4ede15ce86c788cc744d13f4adea1d398ba9f8bd Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 4 Jan 2026 22:56:46 +0100 Subject: [PATCH 30/72] fix: fixed type errors --- .../_external/_storage/__init__.py | 1 - .../_snapshot/collection_value.py | 2 ++ src/inline_snapshot/_snapshot/dict_value.py | 1 + .../_snapshot/generic_value.py | 12 +++++----- .../_snapshot/undecided_value.py | 4 ++-- src/inline_snapshot/_snapshot_session.py | 8 ++++--- src/inline_snapshot/plugin/_spec.py | 24 ++++++++++++------- tests/test_docs.py | 5 +++- 8 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/inline_snapshot/_external/_storage/__init__.py b/src/inline_snapshot/_external/_storage/__init__.py index 45220b1b..7766487e 100644 --- a/src/inline_snapshot/_external/_storage/__init__.py +++ b/src/inline_snapshot/_external/_storage/__init__.py @@ -9,5 +9,4 @@ def default_storages(storage_dir: Path): - return {"hash": HashStorage(storage_dir / "external"), "uuid": UuidStorage()} diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index b36d53a8..7afa3643 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -16,6 +16,8 @@ class CollectionValue(GenericValue): _current_op = "x in snapshot" + _ast_node: ast.List | ast.Tuple + _new_value: CustomList def __contains__(self, item): if isinstance(self._old_value, CustomUndefined): diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index aa98c941..efec48f5 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -19,6 +19,7 @@ class DictValue(GenericValue): _new_value: CustomDict _old_value: CustomDict + _ast_node: ast.Dict def __getitem__(self, index): if isinstance(self._new_value, CustomUndefined): diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 8d803b21..2ce33fe2 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -24,7 +24,7 @@ class GenericValue(SnapshotBase): _new_value: Custom _old_value: Custom _current_op = "undefined" - _ast_node: ast.Expr + _ast_node: ast.expr _context: AdapterContext def get_builder(self, **args): @@ -101,22 +101,22 @@ def _type_error(self, op): f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`" ) - def __eq__(self, _other): + def __eq__(self, other): __tracebackhide__ = True self._type_error("==") - def __le__(self, _other): + def __le__(self, other): __tracebackhide__ = True self._type_error("<=") - def __ge__(self, _other): + def __ge__(self, other): __tracebackhide__ = True self._type_error(">=") - def __contains__(self, _other): + def __contains__(self, item): __tracebackhide__ = True self._type_error("in") - def __getitem__(self, _item): + def __getitem__(self, item): __tracebackhide__ = True self._type_error("snapshot[key]") diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index f0e3eddc..e4c99a1a 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -118,11 +118,11 @@ def __ge__(self, other): self._change(MaxValue) return self >= other - def __contains__(self, other): + def __contains__(self, item): from .._snapshot.collection_value import CollectionValue self._change(CollectionValue) - return other in self + return item in self def __getitem__(self, item): from .._snapshot.dict_value import DictValue diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index 21f2213d..429c6700 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -356,10 +356,12 @@ def load_config(self, pyproject, cli_flags, parallel_run, error, project_root): state().active = "disable" not in flags state().update_flags = Flags(flags & categories) - if state().config.storage_dir is None: - state().config.storage_dir = project_root / ".inline-snapshot" + storage_dir = state().config.storage_dir + if storage_dir is None: + storage_dir = project_root / ".inline-snapshot" + state().config.storage_dir = storage_dir - state().all_storages = default_storages(state().config.storage_dir) + state().all_storages = default_storages(storage_dir) if flags - {"short-report", "disable"} and not is_pytest_compatible(): diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index 28361542..8e548033 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -30,17 +30,19 @@ Example: Basic usage with a custom class: - - + ``` python - from inline_snapshot import customize, snapshot - - class MyClass: def __init__(self, arg1, arg2, key=None): self.arg1 = arg1 self.arg2 = arg2 self.key_attr = key + ``` + + + ``` python + from myclass import MyClass + from inline_snapshot import customize @customize @@ -51,6 +53,12 @@ def my_custom_handler(value, builder): MyClass, [value.arg1, value.arg2], {"key": value.key_attr} ) return None # Let other handlers process this value + ``` + + + ``` python + from inline_snapshot import snapshot + from myclass import MyClass def test_myclass(): @@ -60,11 +68,11 @@ def test_myclass(): Note: - **Always register handlers in `conftest.py`** to ensure they're available for all tests - - Handlers are called in **reverse order** of registration (last registered is called first) - If no handler returns a Custom object, a default representation is used - - Use builder methods (`create_call`, `create_list`, `create_dict`, etc.) to construct representations + - Use builder methods (`create_call`, `create_external`) to construct representations - Always return `None` if your handler doesn't apply to the given value type - - The builder automatically handles recursive conversion of nested values + - The builder automatically handles recursive conversion of nested values, therfor `create_list` and `create_dict` are unlikely needed because you can just use `[]` or `{}` + See Also: - [Builder][inline_snapshot._customize.Builder]: Available builder methods diff --git a/tests/test_docs.py b/tests/test_docs.py index 26ea291c..2a012673 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -50,11 +50,13 @@ def map_code_blocks(file: Path, func): block_options: Optional[str] = None code_header = None header_line = "" + block_found = False for linenumber, line in enumerate(current_code.splitlines(), start=1): m = block_start.fullmatch(line) if m and not is_block: # ``` python + block_found = True block_start_linenum = linenumber indent = m[1] block_options = m[2] @@ -130,7 +132,8 @@ def map_code_blocks(file: Path, func): new_code = "\n".join(new_lines) + "\n" - assert external_file(file, format=".txt") == new_code + if block_found: + assert external_file(file, format=".txt") == new_code def test_map_code_blocks(tmp_path): From 2774c6121cd33eac70b174df8de387cc9946a805 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 5 Jan 2026 20:51:04 +0100 Subject: [PATCH 31/72] test: coverage --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d528b5bf..81f9b7e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ exclude_lines = [ "# pragma: no cover", "if TYPE_CHECKING:", "if is_insider", - ": ..." + "\\.\\.\\." ] From ba3c7c66e0379a374e688efbeb0f379744657215 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 5 Jan 2026 23:27:00 +0100 Subject: [PATCH 32/72] docs: added docs for customize --- conftest.py | 10 ++ docs/customize.md | 156 +++++++++++++++--- src/inline_snapshot/__init__.py | 2 - src/inline_snapshot/_customize/_builder.py | 4 +- src/inline_snapshot/_customize/_custom.py | 20 +-- src/inline_snapshot/plugin/_default_plugin.py | 6 +- src/inline_snapshot/plugin/_spec.py | 6 - 7 files changed, 156 insertions(+), 48 deletions(-) diff --git a/conftest.py b/conftest.py index 1b4ee73f..e7747f1a 100644 --- a/conftest.py +++ b/conftest.py @@ -10,3 +10,13 @@ def snapshot_env_for_doctest(request): yield else: yield + + +from inline_snapshot import customize, Builder +from dirty_equals import IsNow + + +@customize +def is_now_handler(value, builder: Builder): + if value == IsNow(): + return IsNow diff --git a/docs/customize.md b/docs/customize.md index 69262ce5..df32fdb3 100644 --- a/docs/customize.md +++ b/docs/customize.md @@ -1,19 +1,20 @@ -`@customize` allows you to register special hooks that control how inline-snapshot generates your snapshots. +`@customize` allows you to register special hooks to control how inline-snapshot generates your snapshots. You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. -inline-snapshot calls each hook until it finds one that returns a custom object, which can be created with the `create_*` methods of the [`Builder`][inline_snapshot.Builder]. +`@customize` hooks can have the following arguments (but you do not have to use all of them). -One use case might be that you have a dataclass with a special constructor function that can be used for certain instances of this dataclass, and you want inline-snapshot to use this constructor when possible. +* **value:** the value of your snapshot that is currently being converted to source code. +* **builder:** your `Builder` object can be used to create Custom objects that represent your new code. +* **local_vars:** a list of objects with `name` and `value` attributes that represent the local variables that are usable in your snapshot. +* **global_vars:** same as for `local_vars`, but for global variables. +## Custom constructor methods +One use case might be that you have a dataclass with a special constructor function that can be used for specific instances of this dataclass, and you want inline-snapshot to use this constructor when possible. - + ``` python from dataclasses import dataclass -from inline_snapshot import customize -from inline_snapshot import Builder -from inline_snapshot import snapshot - @dataclass class Rect: @@ -23,12 +24,31 @@ class Rect: @staticmethod def make_quadrat(size): return Rect(size, size) +``` + +You can define a hook in your `conftest.py` that checks if your value is a square and calls the correct constructor function. +Inline-snapshot tries each hook until it finds one that does not return None. +It keeps converting this value until a hook returns a Custom object, which can be created with the `create_*` methods of the [`Builder`][inline_snapshot.Builder]. + + +``` python +from rect import Rect +from inline_snapshot import customize +from inline_snapshot import Builder @customize def quadrat_handler(value, builder: Builder): if isinstance(value, Rect) and value.width == value.height: return builder.create_call(Rect.make_quadrat, [value.width]) +``` + +This allows you to influence the code that is created by inline-snapshot. + + +``` python +from inline_snapshot import snapshot +from rect import Rect def test_quadrat(): @@ -37,43 +57,141 @@ def test_quadrat(): assert Rect(1, 2) == snapshot(Rect(width=1, height=2)) # (3)! ``` -1. Your handler is used because you created a quadrat +1. Your handler is used because you created a square 2. Your handler is used because you created a rect that happens to have the same width and height 3. Your handler is not used because width and height are different +## dirty-equal expressions +It can also be used to instruct inline-snapshot to use specific dirty-equals expressions for specific values. + + +``` python +from inline_snapshot import customize +from inline_snapshot import Builder +from dirty_equals import IsNow -It can also be used to teach inline-snapshot to use specific dirty-equals expressions for specific values. +@customize +def is_now_handler(value): + if value == IsNow(): + return IsNow +``` + +Inline-snapshot provides a handler that can convert dirty-equals expressions back into source code. This allows you to return `IsNow` here without the need to construct a custom object with the builder. +This works because the value is converted with the customize functions until one hook uses the builder to create a Custom object. ``` python -from dataclasses import dataclass +from inline_snapshot import snapshot +from datetime import datetime + +from dirty_equals import IsNow # (1)! + + +def test_is_now(): + assert datetime.now() == snapshot(IsNow) +``` + +1. Inline-snapshot also creates the imports when they are missing + +!!! important + Inline-snapshot will never change the dirty-equals expressions in your code because they are unmanaged. + Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it, because it has to assume that it was created by the user. + +## Conditional external objects + +`create_external` can be used to store values in external files if a specific criterion is met. + + +``` python from inline_snapshot import customize from inline_snapshot import Builder -from inline_snapshot import snapshot - from dirty_equals import IsNow + + +@customize +def is_now_handler(value, builder: Builder): + if isinstance(value, str) and value.count("\n") > 5: + return builder.create_external(value) +``` + + +``` python +from inline_snapshot import snapshot from datetime import datetime +from inline_snapshot import external + + +def test_long_strings(): + assert "a\nb\nc" == snapshot( + """\ +a +b +c\ +""" + ) + assert "a\n" * 50 == snapshot( + external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt") + ) +``` + +## Reusing local variables + +There are times when your local or global variables become part of your snapshots, like uuids or user names. +Customize hooks accept `local_vars` and `global_vars` as arguments that can be used to generate the code. + + +``` python title="conftest.py" +from inline_snapshot import customize +from inline_snapshot import Builder + @customize -def quadrat_handler(value, builder: Builder): - if value == IsNow(): - return builder.create_call(IsNow) +def local_var_handler(value, local_vars): + for local in local_vars: + if local.name.startswith("v_") and local.value == value: + return local +``` +We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local if we find one that fits the criteria. -def test_quadrat(): - assert datetime.now() == snapshot(IsNow()) + + +``` python title="test_user.py" +from inline_snapshot import snapshot +from datetime import datetime + +from inline_snapshot import external + + +def get_data(user): + return {"user": user, "age": 55} + + +def test_user(): + v_user = "Bob" + some_number = 50 + 5 + + assert get_data(v_user) == snapshot({"user": v_user, "age": 55}) ``` +Inline-snapshot uses `v_user` because it met the criteria in your customization hook, but not `some_number` because it does not start with `v_`. +You can also do this only for specific types of objects or for a whitelist of variable names. +It is up to you to set the rules that work best in your project. +!!! note + It is not recommended to check only for the value because this might result in local variables which become part of the snapshot just because they are equal to the value and not because they should be there (see `age=55` in the example above). + This is also the reason why inline-snapshot does not provide default customizations that check your local variables. + The rules are project specific and what might work well for one project can cause problems for others. +# Reference ::: inline_snapshot options: heading_level: 3 - members: [customize,Builder,Custom,CustomizeHandler] + members: [customize,Builder,Custom] show_root_heading: false show_bases: false show_source: false diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 86edcd24..48bb7569 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -5,7 +5,6 @@ from ._code_repr import customize_repr from ._customize._builder import Builder from ._customize._custom import Custom -from ._customize._custom import CustomizeHandler from ._exceptions import UsageError from ._external._external import external from ._external._external_file import external_file @@ -44,5 +43,4 @@ "customize", "Custom", "Builder", - "CustomizeHandler", ] diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 33e8ed21..9e9a0c39 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -62,7 +62,9 @@ def _get_handler(self, v) -> Custom: result.__dict__["original_value"] = v return result - def create_external(self, value: Any, format: str | None, storage: str | None): + def create_external( + self, value: Any, format: str | None = None, storage: str | None = None + ): return CustomExternal(value, format=format, storage=storage) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index e1847522..eabc34e6 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -5,15 +5,13 @@ from abc import abstractmethod from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Generator -from typing import TypeAlias from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase if TYPE_CHECKING: - from inline_snapshot._customize._builder import Builder + pass class Custom(ABC): @@ -40,19 +38,3 @@ def eval(self): def _needed_imports(self): yield from () - - -CustomizeHandler: TypeAlias = Callable[[Any, "Builder"], Custom | None] -""" -Type alias for customization handler functions. - -A customization handler is a function that takes a Python value and a Builder, -and returns either a Custom representation or None. - -The handler receives two parameters: - -- `value` (Any): The Python object to be converted to snapshot code -- `builder` (Builder): Helper object providing methods to create Custom representations - -The handler should return a Custom object if it processes the value type, or None otherwise. -""" diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 23fb77b2..0f497e26 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -188,18 +188,22 @@ def undefined_handler(self, value, builder: Builder): def dirty_equals_handler(self, value, builder: Builder): if is_dirty_equal(value) and builder._build_new_value: + if isinstance(value, type): return builder.create_value(value, value.__name__).with_import( "dirty_equals", value.__name__ ) else: + from dirty_equals import IsNow from dirty_equals._utils import Omit args = [a for a in value._repr_args if a is not Omit] kwargs = {k: a for k, a in value._repr_kwargs.items() if a is not Omit} + if type(value) == IsNow: + kwargs.pop("approx") return builder.create_call(type(value), args, kwargs) - @customize + @customize(tryfirst=True) def context_value_handler(self, value, builder: Builder): if isinstance(value, ContextValue): return builder.create_value(value.value, value.name) diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index 8e548033..3f684fd0 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -21,12 +21,6 @@ **Important**: Customization handlers should be registered in your `conftest.py` file to ensure they are loaded before your tests run. - Args: - f: A customization handler function. See [CustomizeHandler][inline_snapshot._customize.CustomizeHandler] - for the expected signature. - - Returns: - The input function unchanged (for use as a decorator) Example: Basic usage with a custom class: From b9fe8a2778cd66ff1d50822073d9b0e2f8b44f1d Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 6 Jan 2026 08:51:19 +0100 Subject: [PATCH 33/72] fix: removed customize function from my conftest.py --- conftest.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/conftest.py b/conftest.py index e7747f1a..1b4ee73f 100644 --- a/conftest.py +++ b/conftest.py @@ -10,13 +10,3 @@ def snapshot_env_for_doctest(request): yield else: yield - - -from inline_snapshot import customize, Builder -from dirty_equals import IsNow - - -@customize -def is_now_handler(value, builder: Builder): - if value == IsNow(): - return IsNow From b71a4ea784a9d83830d62dc59e9e480b157eebb6 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 8 Jan 2026 08:55:22 +0100 Subject: [PATCH 34/72] feat!: deprecated @customize_repr and import only InlineSnapshot* classes from conftest.py --- docs/customize.md | 38 +++++++------- src/inline_snapshot/_code_repr.py | 8 +++ src/inline_snapshot/_snapshot_session.py | 13 +---- src/inline_snapshot/plugin/_spec.py | 2 +- tests/adapter/test_dataclass.py | 9 ++-- tests/conftest.py | 21 ++++---- tests/test_customize.py | 9 ++-- tests/test_docs.py | 65 ++++++++++++------------ 8 files changed, 87 insertions(+), 78 deletions(-) diff --git a/docs/customize.md b/docs/customize.md index df32fdb3..eddd3084 100644 --- a/docs/customize.md +++ b/docs/customize.md @@ -37,10 +37,11 @@ from inline_snapshot import customize from inline_snapshot import Builder -@customize -def quadrat_handler(value, builder: Builder): - if isinstance(value, Rect) and value.width == value.height: - return builder.create_call(Rect.make_quadrat, [value.width]) +class InlineSnapshotExtension: + @customize + def quadrat_handler(self, value, builder: Builder): + if isinstance(value, Rect) and value.width == value.height: + return builder.create_call(Rect.make_quadrat, [value.width]) ``` This allows you to influence the code that is created by inline-snapshot. @@ -71,10 +72,11 @@ from inline_snapshot import Builder from dirty_equals import IsNow -@customize -def is_now_handler(value): - if value == IsNow(): - return IsNow +class InlineSnapshotExtension: + @customize + def is_now_handler(self, value): + if value == IsNow(): + return IsNow ``` Inline-snapshot provides a handler that can convert dirty-equals expressions back into source code. This allows you to return `IsNow` here without the need to construct a custom object with the builder. @@ -110,10 +112,11 @@ from inline_snapshot import Builder from dirty_equals import IsNow -@customize -def is_now_handler(value, builder: Builder): - if isinstance(value, str) and value.count("\n") > 5: - return builder.create_external(value) +class InlineSnapshotExtension: + @customize + def is_now_handler(self, value, builder: Builder): + if isinstance(value, str) and value.count("\n") > 5: + return builder.create_external(value) ``` @@ -148,11 +151,12 @@ from inline_snapshot import customize from inline_snapshot import Builder -@customize -def local_var_handler(value, local_vars): - for local in local_vars: - if local.name.startswith("v_") and local.value == value: - return local +class InlineSnapshotExtension: + @customize + def local_var_handler(self, value, local_vars): + for local in local_vars: + if local.name.startswith("v_") and local.value == value: + return local ``` We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local if we find one that fits the criteria. diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index a39313f6..964f2664 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from contextlib import contextmanager from enum import Enum from enum import Flag @@ -7,6 +8,8 @@ from typing import TYPE_CHECKING from unittest import mock +from typing_extensions import deprecated + from inline_snapshot._generator_utils import only_value if TYPE_CHECKING: @@ -52,6 +55,7 @@ def code_repr_dispatch(value): return real_repr(value) +@deprecated("use @customize instead") def customize_repr(f): """Register a function which should be used to get the code representation of a object. @@ -68,6 +72,10 @@ def _(obj: MyCustomClass): * __repr__() of your class returns a valid code representation, * and __repr__() uses `repr()` to get the representation of the child objects """ + warnings.warn( + "@customize_repr is deprecated, @customize should be used instead", + DeprecationWarning, + ) code_repr_dispatch.register(f) diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index 429c6700..9535f31b 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -248,19 +248,10 @@ def register_customize_hooks_from_module(self, module): self.registered_modules.add(module.__file__) - class ConftestPlugin: - pass - for name in dir(module): obj = getattr(module, name, None) - if obj is None or not callable(obj): - continue - - # Check if the function has the customize hookimpl marker - if hasattr(obj, "inline_snapshot_impl"): - setattr(ConftestPlugin, name, obj) - - state().pm.register(ConftestPlugin, name=f"") + if isinstance(obj, type) and name.startswith("InlineSnapshot"): + state().pm.register(obj(), name=f"") @staticmethod def test_enter(): diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index 3f684fd0..6d8ba784 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -49,7 +49,7 @@ def my_custom_handler(value, builder): return None # Let other handlers process this value ``` - + ``` python from inline_snapshot import snapshot from myclass import MyClass diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 8bdfbdd0..6aa073c8 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -502,10 +502,11 @@ def __eq__(self,other): from inline_snapshot import customize from helper import L -@customize -def handle_L(value,builder): - if isinstance(value,L): - return builder.create_call(L,value.l) +class InlineSnapshotExtension: + @customize + def handle_L(self,value,builder): + if isinstance(value,L): + return builder.create_call(L,value.l) """, "tests/test_something.py": """\ from inline_snapshot import snapshot diff --git a/tests/conftest.py b/tests/conftest.py index 593e0d7f..2746e53a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -311,15 +311,18 @@ def setup(self, source: str, add_header=True): import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize_repr - -@customize_repr -def _(value:FakeDatetime): - return value.__repr__().replace("FakeDatetime","datetime.datetime") - -@customize_repr -def _(value:FakeDate): - return value.__repr__().replace("FakeDate","datetime.date") +from inline_snapshot import customize + +class InlineSnapshotExtension: + @customize + def fakedatetime_handler(self,value,builder): + if isinstance(value,FakeDatetime): + return builder.create_value(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) + + @customize + def fakedate_handler(self,value,builder): + if isinstance(value,FakeDate): + return builder.create_value(value,value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) diff --git a/tests/test_customize.py b/tests/test_customize.py index 616becff..4d96de99 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -16,10 +16,11 @@ def test_custom_dirty_equal(original, flag): from inline_snapshot import Builder from dirty_equals import IsStr -@customize -def re_handler(value, builder: Builder): - if value == IsStr(regex="[a-z]"): - return builder.create_call(IsStr, [], {"regex": "[a-z]"}) +class InlineSnapshotExtension: + @customize + def re_handler(self,value, builder: Builder): + if value == IsStr(regex="[a-z]"): + return builder.create_call(IsStr, [], {"regex": "[a-z]"}) """, "tests/test_something.py": f"""\ from inline_snapshot import snapshot diff --git a/tests/test_docs.py b/tests/test_docs.py index 2a012673..be56d7c8 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -17,6 +17,7 @@ from executing import is_pytest_compatible from inline_snapshot import snapshot +from inline_snapshot._align import align from inline_snapshot._external._external_file import external_file from inline_snapshot._flags import Flags from inline_snapshot._global_state import snapshot_env @@ -289,7 +290,6 @@ def __eq__(self, other: Any): def file_test( file: Path, width: int = 80, - use_hl_lines: bool = True, ): """Test code blocks with the header @@ -313,15 +313,18 @@ def file_test( import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize_repr +from inline_snapshot import customize -@customize_repr -def _(value:FakeDatetime): - return value.__repr__().replace("FakeDatetime","datetime.datetime") +class InlineSnapshotExtension: + @customize + def fakedatetime_handler(self,value,builder): + if isinstance(value,FakeDatetime): + return builder.create_value(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) -@customize_repr -def _(value:FakeDate): - return value.__repr__().replace("FakeDate","datetime.date") + @customize + def fakedate_handler(self,value,builder): + if isinstance(value,FakeDate): + return builder.create_value(value,value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) @@ -430,30 +433,28 @@ def test_block(block: Block): ] ) - if use_hl_lines: - from inline_snapshot._align import align - - linenum = 1 - hl_lines = "" - - if last_code is not None and "first_block" not in options: - changed_lines = [] - alignment = align(last_code.split("\n"), new_code.split("\n")) - for c in alignment: - if c == "d": - continue - elif c == "m": - linenum += 1 - else: - changed_lines.append(str(linenum)) - linenum += 1 - if changed_lines: - hl_lines = f'hl_lines="{" ".join(changed_lines)}"' + linenum = 1 + hl_lines = "" + + if last_code is not None and "first_block" not in options: + changed_lines = [] + alignment = align(last_code.split("\n"), new_code.split("\n")) + for c in alignment: + if c == "d": + continue + elif c == "m": + linenum += 1 else: - assert False, "no lines changed" - block.block_options = hl_lines - else: - pass # pragma: no cover + changed_lines.append(str(linenum)) + linenum += 1 + if changed_lines: + hl_lines = f'hl_lines="{" ".join(changed_lines)}"' + else: + assert False, "no lines changed" + + old_options = re.sub(r'hl_lines="[^"]*"', "", block.block_options).strip() + + block.block_options = f"{old_options} {hl_lines}".strip() block.code = new_code @@ -470,4 +471,4 @@ def test_block(block: Block): print(file) - file_test(file, width=60, use_hl_lines=True) + file_test(file, width=60) From 0b9ccb6593896c978ad955b19683dfe64c260441 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 8 Jan 2026 18:46:43 +0100 Subject: [PATCH 35/72] docs: support for title="..." in code-blocks --- docs/customize.md | 14 ++++++------ docs/eq_snapshot.md | 4 ++-- src/inline_snapshot/plugin/_spec.py | 4 ++-- tests/test_docs.py | 34 ++++++++++++++--------------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/customize.md b/docs/customize.md index eddd3084..e2207125 100644 --- a/docs/customize.md +++ b/docs/customize.md @@ -12,7 +12,7 @@ You should use it when you find yourself manually editing snapshots after they w One use case might be that you have a dataclass with a special constructor function that can be used for specific instances of this dataclass, and you want inline-snapshot to use this constructor when possible. -``` python +``` python title="rect.py" from dataclasses import dataclass @@ -31,7 +31,7 @@ Inline-snapshot tries each hook until it finds one that does not return None. It keeps converting this value until a hook returns a Custom object, which can be created with the `create_*` methods of the [`Builder`][inline_snapshot.Builder]. -``` python +``` python title="conftest.py" from rect import Rect from inline_snapshot import customize from inline_snapshot import Builder @@ -47,7 +47,7 @@ class InlineSnapshotExtension: This allows you to influence the code that is created by inline-snapshot. -``` python +``` python title="test_quadrat.py" from inline_snapshot import snapshot from rect import Rect @@ -66,7 +66,7 @@ def test_quadrat(): It can also be used to instruct inline-snapshot to use specific dirty-equals expressions for specific values. -``` python +``` python title="conftest.py" from inline_snapshot import customize from inline_snapshot import Builder from dirty_equals import IsNow @@ -83,7 +83,7 @@ Inline-snapshot provides a handler that can convert dirty-equals expressions bac This works because the value is converted with the customize functions until one hook uses the builder to create a Custom object. -``` python +``` python title="test_is_now.py" from inline_snapshot import snapshot from datetime import datetime @@ -106,7 +106,7 @@ def test_is_now(): `create_external` can be used to store values in external files if a specific criterion is met. -``` python +``` python title="conftest.py" from inline_snapshot import customize from inline_snapshot import Builder from dirty_equals import IsNow @@ -120,7 +120,7 @@ class InlineSnapshotExtension: ``` -``` python +``` python title="test_long_strings.py" from inline_snapshot import snapshot from datetime import datetime diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index d5d3d4ad..d5af2cb2 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -326,7 +326,7 @@ The following example shows how this can be used to run a tests with two differe === "my_lib.py v1" - ``` python + ``` python title="my_lib.py" version = 1 @@ -337,7 +337,7 @@ The following example shows how this can be used to run a tests with two differe === "my_lib.py v2" - ``` python + ``` python title="my_lib.py" version = 2 diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index 6d8ba784..553394a3 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -25,7 +25,7 @@ Example: Basic usage with a custom class: - ``` python + ``` python title="myclass.py" class MyClass: def __init__(self, arg1, arg2, key=None): self.arg1 = arg1 @@ -34,7 +34,7 @@ def __init__(self, arg1, arg2, key=None): ``` - ``` python + ``` python title="conftest.py" from myclass import MyClass from inline_snapshot import customize diff --git a/tests/test_docs.py b/tests/test_docs.py index be56d7c8..518686f2 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -30,7 +30,7 @@ class Block: code: str code_header: Optional[str] - block_options: str + block_options: Dict[str, str] line: int @@ -60,7 +60,7 @@ def map_code_blocks(file: Path, func): block_found = True block_start_linenum = linenumber indent = m[1] - block_options = m[2] + block_options = {m[0]: m[1] for m in re.findall(r'(\w*)="([^"]*)"', m[2])} block_lines = [] is_block = True continue @@ -95,9 +95,9 @@ def map_code_blocks(file: Path, func): if new_block.code_header is not None: new_lines.append(f"{indent}") - new_lines.append( - f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" - ) + options = " ".join(f'{k}="{v}"' for k, v in new_block.block_options.items()) + + new_lines.append(f"{indent}``` {('python '+options).strip()}") new_code = new_block.code.rstrip() if file.suffix == ".py": @@ -206,12 +206,12 @@ def test_block(block): blocks=snapshot( [ Block( - code="print(1 + 1)\n", code_header=None, block_options="", line=2 + code="print(1 + 1)\n", code_header=None, block_options={}, line=2 ), Block( code="print(1 - 1)\n", code_header="inline-snapshot: create test", - block_options=' hl_lines="1 2 3"', + block_options={"hl_lines": "1 2 3"}, line=7, ), ] @@ -221,7 +221,7 @@ def test_block(block): def change_block(block): block.code = "# removed" block.code_header = "header" - block.block_options = "option a b c" + block.block_options = {"a": "b c"} test_doc( """\ @@ -236,7 +236,7 @@ def change_block(block): Block( code="# removed", code_header="header", - block_options="option a b c", + block_options={"a": "b c"}, line=2, ) ] @@ -245,7 +245,7 @@ def change_block(block): """\ text -``` python option a b c +``` python a="b c" # removed ``` """ @@ -353,11 +353,15 @@ def test_block(block: Block): return block if block.code_header.startswith("inline-snapshot-lib:"): - extra_files[block.code_header.split()[1]].append(block.code) + name = block.code_header.split()[1] + extra_files[name].append(block.code) + block.block_options["title"] = name return block if block.code_header.startswith("inline-snapshot-lib-set:"): - extra_files[block.code_header.split()[1]] = [block.code] + name = block.code_header.split()[1] + extra_files[name] = [block.code] + block.block_options["title"] = name return block if block.code_header.startswith("todo-inline-snapshot:"): @@ -448,14 +452,10 @@ def test_block(block: Block): changed_lines.append(str(linenum)) linenum += 1 if changed_lines: - hl_lines = f'hl_lines="{" ".join(changed_lines)}"' + block.block_options["hl_lines"] = " ".join(changed_lines) else: assert False, "no lines changed" - old_options = re.sub(r'hl_lines="[^"]*"', "", block.block_options).strip() - - block.block_options = f"{old_options} {hl_lines}".strip() - block.code = new_code last_code = code From c12ea751a20ce215963fd9d14ab0d34d23e3b894 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 8 Jan 2026 23:25:22 +0100 Subject: [PATCH 36/72] docs: added docs for customize --- docs/categories.md | 12 +- docs/customize.md | 105 +++++++++++++++--- docs/customize_repr.md | 91 +++++++++------ src/inline_snapshot/_code_repr.py | 17 --- src/inline_snapshot/_customize/_builder.py | 10 +- .../_customize/_custom_value.py | 4 +- src/inline_snapshot/_new_adapter.py | 10 +- .../_snapshot/undecided_value.py | 4 +- src/inline_snapshot/plugin/_default_plugin.py | 46 ++++++-- src/inline_snapshot/testing/_example.py | 5 + tests/conftest.py | 4 +- tests/test_code_repr.py | 42 ++++--- tests/test_docs.py | 4 +- 13 files changed, 240 insertions(+), 114 deletions(-) diff --git a/docs/categories.md b/docs/categories.md index 55c31e48..157a5ba2 100644 --- a/docs/categories.md +++ b/docs/categories.md @@ -184,8 +184,10 @@ It is recommended to use trim only if you run your complete test suite. ### Update Changes in the update category do not change the value of the snapshot, just its representation. -These updates are not shown by default in your reports and can be enabled with [show-updates](configuration.md/#show-updates). -The reason might be that `#!python repr()` of the object has changed or that inline-snapshot provides some new logic which changes the representation. Like with the strings in the following example: +These updates are not shown by default in your reports, because it can be confusing for users who uses inline-snapshot the first time or want to change the snapshot values manual. +Updates can be enabled with [show-updates](configuration.md/#show-updates). + +The reason for updates might be that `#!python repr()` of the object has changed or that inline-snapshot provides some new logic which changes the representation. Like with the strings in the following example: === "original" @@ -254,7 +256,7 @@ The reason might be that `#!python repr()` of the object has changed or that inl ``` -The approval of this type of changes is easier, because inline-snapshot assures that the value has not changed. +The approval of this type of changes is easier, because the update category assures that the value has not changed. The goal of inline-snapshot is to generate the values for you in the correct format so that no manual editing is required. This improves your productivity and saves time. @@ -270,5 +272,5 @@ You can agree with inline-snapshot and accept the changes or you can use one of 4. you can also open an [issue](https://github.com/15r10nk/inline-snapshot/issues?q=is%3Aissue%20state%3Aopen%20label%3Aupdate_related) if you have a specific problem with the way inline-snapshot generates the code. -!!! note: - [#177](https://github.com/15r10nk/inline-snapshot/issues/177) will give the developer more control about how snapshots are created. *update* will them become much more useful. +!!! note + [#177](https://github.com/15r10nk/inline-snapshot/issues/177) will give the developer more control about how snapshots are created. *update* will then become much more useful. diff --git a/docs/customize.md b/docs/customize.md index e2207125..1a0d71b5 100644 --- a/docs/customize.md +++ b/docs/customize.md @@ -6,7 +6,7 @@ You should use it when you find yourself manually editing snapshots after they w * **value:** the value of your snapshot that is currently being converted to source code. * **builder:** your `Builder` object can be used to create Custom objects that represent your new code. * **local_vars:** a list of objects with `name` and `value` attributes that represent the local variables that are usable in your snapshot. -* **global_vars:** same as for `local_vars`, but for global variables. +* **global_vars:** same as `local_vars`, but for global variables. ## Custom constructor methods One use case might be that you have a dataclass with a special constructor function that can be used for specific instances of this dataclass, and you want inline-snapshot to use this constructor when possible. @@ -22,7 +22,7 @@ class Rect: height: int @staticmethod - def make_quadrat(size): + def make_square(size): return Rect(size, size) ``` @@ -39,27 +39,27 @@ from inline_snapshot import Builder class InlineSnapshotExtension: @customize - def quadrat_handler(self, value, builder: Builder): + def square_handler(self, value, builder: Builder): if isinstance(value, Rect) and value.width == value.height: - return builder.create_call(Rect.make_quadrat, [value.width]) + return builder.create_call(Rect.make_square, [value.width]) ``` This allows you to influence the code that is created by inline-snapshot. -``` python title="test_quadrat.py" +``` python title="test_square.py" from inline_snapshot import snapshot from rect import Rect -def test_quadrat(): - assert Rect.make_quadrat(5) == snapshot(Rect.make_quadrat(5)) # (1)! - assert Rect(1, 1) == snapshot(Rect.make_quadrat(1)) # (2)! +def test_square(): + assert Rect.make_square(5) == snapshot(Rect.make_square(5)) # (1)! + assert Rect(1, 1) == snapshot(Rect.make_square(1)) # (2)! assert Rect(1, 2) == snapshot(Rect(width=1, height=2)) # (3)! ``` 1. Your handler is used because you created a square -2. Your handler is used because you created a rect that happens to have the same width and height +2. Your handler is used because you created a Rect that happens to have the same width and height 3. Your handler is not used because width and height are different ## dirty-equal expressions @@ -98,7 +98,7 @@ def test_is_now(): !!! important Inline-snapshot will never change the dirty-equals expressions in your code because they are unmanaged. - Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it, because it has to assume that it was created by the user. + Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it when you change the `@customize` implementation, because it has to assume that it was created by the user. ## Conditional external objects @@ -114,7 +114,7 @@ from dirty_equals import IsNow class InlineSnapshotExtension: @customize - def is_now_handler(self, value, builder: Builder): + def long_string_handler(self, value, builder: Builder): if isinstance(value, str) and value.count("\n") > 5: return builder.create_external(value) ``` @@ -159,15 +159,12 @@ class InlineSnapshotExtension: return local ``` -We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local if we find one that fits the criteria. +We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local variable if we find one that fits the criteria. ``` python title="test_user.py" from inline_snapshot import snapshot -from datetime import datetime - -from inline_snapshot import external def get_data(user): @@ -188,7 +185,83 @@ It is up to you to set the rules that work best in your project. !!! note It is not recommended to check only for the value because this might result in local variables which become part of the snapshot just because they are equal to the value and not because they should be there (see `age=55` in the example above). This is also the reason why inline-snapshot does not provide default customizations that check your local variables. - The rules are project specific and what might work well for one project can cause problems for others. + The rules are project-specific and what might work well for one project can cause problems for others. + +## Creating special code + +Let's say that you have an array of secrets which are used in your code. + + +``` python title="my_secrets.py" +secrets = ["some_secret", "some_other_secret"] +``` + + +``` python title="get_data.py" +from my_secrets import secrets + + +def get_data(): + return {"data": "large data block", "used_secret": secrets[1]} +``` + +The problem is that `--inline-snapshot=create` puts your secret into your test. + + +``` python +from inline_snapshot import snapshot +from get_data import get_data + + +def test_my_class(): + assert get_data() == snapshot( + {"data": "large data block", "used_secret": "some_other_secret"} + ) +``` + +Maybe this is not what you want because the secret is a different one in CI or for every test run or the raw value leads to unreadable tests. +What you can do now, instead of replacing `"some_other_secret"` with `secrets[1]` by hand, is to tell inline-snapshot how it should generate this code in your *conftest.py*. + + +``` python title="conftest.py" +from my_secrets import secrets +from inline_snapshot import customize, Builder + + +class InlineSnapshotExtension: + @customize + def secret_handler(self, value, builder: Builder): + for i, secret in enumerate(secrets): + if value == secret: + return builder.create_code(secret, f"secrets[{i}]").with_import( + "my_secrets", "secrets" + ) +``` + +Inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. + + +``` python hl_lines="4 5 9" +from inline_snapshot import snapshot +from get_data import get_data + +from my_secrets import secrets + + +def test_my_class(): + assert get_data() == snapshot( + {"data": "large data block", "used_secret": secrets[1]} + ) +``` + +!!! question "why update?" + `"some_other_secret"` was already a correct value for your assertion and `--inline-snapshot=fix` only changes code when the current value is not correct and needs to be fixed. `update` is the category for all other changes where inline-snapshot wants to generate different code which represents the same value as the code before. + + You only have to use `update` when you changed your customizations and want to use the new code representations in your existing tests. The new representation is also used by `create` or `fix` when you write new tests. + + The `update` category is not enabled by default for `--inline-snapshot=review/report`. + You can read [here](categories.md#update) more about it. + # Reference diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 5b01c13f..6eec2fc1 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -1,5 +1,28 @@ +!!! warning "deprecated" + `@customize_repr` will be removed in the future because `@customize` provides the same and even more features. + You should use + ``` python title="conftest.py" + class InlineSnapshotExtension: + @customize + def my_class_handler(value, builder): + if isinstance(value, MyClass): + return builder.create_code(value, "my_class_repr") + ``` + + instead of + + ``` python title="conftest.py" + @customize_repr + def my_class_handler(value: MyClass): + return "my_class_repr" + ``` + + `@customize` allows you not only to generate code but also imports and function calls which can be analysed by inline-snapshot. + + +That said, what is/was `@customize_repr` for? `repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type. Here are some examples: @@ -16,10 +39,34 @@ Here are some examples: `customize_repr` can be used to overwrite the default `repr()` behaviour. -The implementation for `Enum` looks like this: +The implementation for `MyClass` could look like this: + + +``` python title="my_class.py" +class MyClass: + def __init__(self, values): + self.values = values.split() + + def __repr__(self): + return repr(self.values) -``` python exec="1" result="python" -print('--8<-- "src/inline_snapshot/_code_repr.py:Enum"') + def __eq__(self, other): + if not isinstance(other, MyClass): + return NotImplemented + return self.values == other.values +``` + +You can specify the `repr()` used by inline-snapshot in your *conftest.py* + + +``` python title="conftest.py" +from my_class import MyClass +from inline_snapshot import customize_repr + + +@customize_repr +def _(value: MyClass): + return f"{MyClass.__qualname__}({' '.join(value.values) !r})" ``` This implementation is then used by inline-snapshot if `repr()` is called during the code generation, but not in normal code. @@ -27,43 +74,21 @@ This implementation is then used by inline-snapshot if `repr()` is called during ``` python from inline_snapshot import snapshot -from enum import Enum +from my_class import MyClass -def test_enum(): - E = Enum("E", ["a", "b"]) +def test_my_class(): + e = MyClass("1 5 hello") # normal repr - assert repr(E.a) == "" + assert repr(e) == "['1', '5', 'hello']" # the special implementation to convert the Enum into a code - assert E.a == snapshot(E.a) + assert e == snapshot(MyClass("1 5 hello")) ``` -## built-in data types - -inline-snapshot comes with a special implementation for the following types: - -``` python exec="1" -from inline_snapshot._code_repr import code_repr_dispatch, code_repr - -for name, obj in sorted( - ( - getattr( - obj, "_inline_snapshot_name", f"{obj.__module__}.{obj.__qualname__}" - ), - obj, - ) - for obj in code_repr_dispatch.registry.keys() -): - if obj is not object: - print(f"- `{name}`") -``` - -Please open an [issue](https://github.com/15r10nk/inline-snapshot/issues) if you found a built-in type which is not supported by inline-snapshot. - !!! note - Container types like `dict`, `list`, `tuple` or `dataclass` are handled in a different way, because inline-snapshot also needs to inspect these types to implement [unmanaged](/eq_snapshot.md#unmanaged-snapshot-values) snapshot values. + The example above can be better handled with `@customize` as shown in the documentation there. ## customize recursive repr @@ -94,8 +119,10 @@ class Pair: return self.a == other.a and self.b == other.b +E = Enum("E", ["a", "b"]) + + def test_enum(): - E = Enum("E", ["a", "b"]) # the special repr implementation is used recursive here # to convert every Enum to the correct representation diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 964f2664..8add9201 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -2,8 +2,6 @@ import warnings from contextlib import contextmanager -from enum import Enum -from enum import Flag from functools import singledispatch from typing import TYPE_CHECKING from unittest import mock @@ -110,18 +108,3 @@ def value_code_repr(obj): result = code_repr_dispatch(obj) return result - - -# -8<- [start:Enum] -@customize_repr -def _(value: Enum): - return f"{type(value).__qualname__}.{value.name}" - - -# -8<- [end:Enum] - - -@customize_repr -def _(value: Flag): - name = type(value).__qualname__ - return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 9e9a0c39..38fcc903 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -14,7 +14,7 @@ from ._custom_external import CustomExternal from ._custom_sequence import CustomList from ._custom_sequence import CustomTuple -from ._custom_value import CustomValue +from ._custom_value import CustomCode @dataclass @@ -55,7 +55,7 @@ def _get_handler(self, v) -> Custom: global_vars=global_vars, ) if r is None: - result = CustomValue(result) + result = CustomCode(result) else: result = r @@ -128,11 +128,11 @@ def create_dict(self, value) -> Custom: custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def create_value(self, value, repr: str | None = None) -> CustomValue: + def create_code(self, value, repr: str | None = None) -> CustomCode: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. - `create_value(my_obj, 'MyClass')` becomes `MyClass` in the code. + `create_code(my_obj, 'MyClass')` becomes `MyClass` in the code. Use this when you want to control the exact string representation of a value. """ - return CustomValue(value, repr) + return CustomCode(value, repr) diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index 9118f7e0..a6bf381e 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -25,7 +25,7 @@ def _simplify_module_path(module: str, name: str) -> str: return ".".join(parts) -class CustomValue(Custom): +class CustomCode(Custom): def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) value = clone(value) @@ -54,7 +54,7 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: return self.repr_str def __repr__(self): - return f"CustomValue({self.repr_str})" + return f"CustomCode({self.repr_str})" def _needed_imports(self): yield from self._imports diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index b2754f55..6102cbd8 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -25,7 +25,7 @@ from inline_snapshot._customize._custom_sequence import CustomSequence from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged -from inline_snapshot._customize._custom_value import CustomValue +from inline_snapshot._customize._custom_value import CustomCode from inline_snapshot._exceptions import UsageError from inline_snapshot._generator_utils import only_value from inline_snapshot.syntax_warnings import InlineSnapshotInfo @@ -106,7 +106,7 @@ def reeval_CustomUndefined(old_value, value): return value -def reeval_CustomValue(old_value: CustomValue, value: CustomValue): +def reeval_CustomCode(old_value: CustomCode, value: CustomCode): if not old_value.eval() == value.eval(): raise UsageError( @@ -168,10 +168,10 @@ def compare( old_value, old_node, new_value ) else: - result = yield from self.compare_CustomValue(old_value, old_node, new_value) + result = yield from self.compare_CustomCode(old_value, old_node, new_value) return result - def compare_CustomValue( + def compare_CustomCode( self, old_value: Custom, old_node: ast.expr, new_value: Custom ) -> Generator[ChangeBase, None, Custom]: @@ -186,7 +186,7 @@ def compare_CustomValue( if ( isinstance(old_node, ast.JoinedStr) - and isinstance(new_value, CustomValue) + and isinstance(new_value, CustomCode) and isinstance(new_value.value, str) ): if not old_value.eval() == new_value.eval(): diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index e4c99a1a..a4cdb9ed 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -10,7 +10,7 @@ from inline_snapshot._customize._custom_sequence import CustomTuple from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged -from inline_snapshot._customize._custom_value import CustomValue +from inline_snapshot._customize._custom_value import CustomCode from inline_snapshot._new_adapter import NewAdapter from inline_snapshot._new_adapter import warn_star_expression from inline_snapshot._unmanaged import is_unmanaged @@ -43,7 +43,7 @@ def convert_generic(self, value: Any, node: ast.expr): if value is ...: return CustomUndefined() else: - return CustomValue(value, ast.unparse(node)) + return CustomCode(value, ast.unparse(node)) def convert_Call(self, value: Any, node: ast.Call): return CustomCall( diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 0f497e26..72ffb8b1 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -4,6 +4,8 @@ from dataclasses import MISSING from dataclasses import fields from dataclasses import is_dataclass +from enum import Enum +from enum import Flag from pathlib import Path from pathlib import PurePath from types import BuiltinFunctionType @@ -44,7 +46,7 @@ def string_handler(self, value, builder: Builder): assert ast.literal_eval(triple_quoted_string) == value - return builder.create_value(value, triple_quoted_string) + return builder.create_code(value, triple_quoted_string) @customize(tryfirst=True) def counter_handler(self, value, builder: Builder): @@ -56,21 +58,21 @@ def function_handler(self, value, builder: Builder): if isinstance(value, FunctionType): qualname = value.__qualname__ name = qualname.split(".")[0] - return builder.create_value(value, qualname).with_import( + return builder.create_code(value, qualname).with_import( value.__module__, name ) @customize def builtin_function_handler(self, value, builder: Builder): if isinstance(value, BuiltinFunctionType): - return builder.create_value(value, value.__name__) + return builder.create_code(value, value.__name__) @customize def type_handler(self, value, builder: Builder): if isinstance(value, type): qualname = value.__qualname__ name = qualname.split(".")[0] - return builder.create_value(value, qualname).with_import( + return builder.create_code(value, qualname).with_import( value.__module__, name ) @@ -100,9 +102,9 @@ def sort_set_values(self, set_values): def set_handler(self, value, builder: Builder): if isinstance(value, set): if len(value) == 0: - return builder.create_value(value, "set()") + return builder.create_code(value, "set()") else: - return builder.create_value( + return builder.create_code( value, "{" + ", ".join(self.sort_set_values(value)) + "}" ) @@ -110,10 +112,36 @@ def set_handler(self, value, builder: Builder): def frozenset_handler(self, value, builder: Builder): if isinstance(value, frozenset): if len(value) == 0: - return builder.create_value(value, "frozenset()") + return builder.create_code(value, "frozenset()") else: return builder.create_call(frozenset, [set(value)]) + # -8<- [start:Enum] + @customize + def enum_handler(self, value, builder: Builder): + if isinstance(value, Enum): + qualname = type(value).__qualname__ + name = qualname.split(".")[0] + + return builder.create_code( + value, f"{type(value).__qualname__}.{value.name}" + ).with_import(type(value).__module__, name) + + # -8<- [end:Enum] + + @customize + def flag_handler(self, value, builder: Builder): + if isinstance(value, Flag): + qualname = type(value).__qualname__ + name = qualname.split(".")[0] + + return builder.create_code( + value, + " | ".join( + f"{qualname}.{flag.name}" for flag in type(value) if flag in value + ), + ).with_import(type(value).__module__, name) + @customize def dataclass_handler(self, value, builder: Builder): @@ -190,7 +218,7 @@ def dirty_equals_handler(self, value, builder: Builder): if is_dirty_equal(value) and builder._build_new_value: if isinstance(value, type): - return builder.create_value(value, value.__name__).with_import( + return builder.create_code(value, value.__name__).with_import( "dirty_equals", value.__name__ ) else: @@ -206,7 +234,7 @@ def dirty_equals_handler(self, value, builder: Builder): @customize(tryfirst=True) def context_value_handler(self, value, builder: Builder): if isinstance(value, ContextValue): - return builder.create_value(value.value, value.name) + return builder.create_code(value.value, value.name) @customize def outsource_handler(self, value, builder: Builder): diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 69f63fb0..7121e360 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -319,6 +319,10 @@ def report_error(message): snapshot_flags = set() old_modules = sys.modules + old_path = sys.path[:] + # Add tmp_path to sys.path so modules can be imported normally + sys.path.insert(0, str(tmp_path)) + try: enter_snapshot_context() session.load_config( @@ -406,6 +410,7 @@ def fail(message): assert stderr == f"ERROR: {e}\n" finally: sys.modules = old_modules + sys.path = old_path leave_snapshot_context() if reported_categories is not None: diff --git a/tests/conftest.py b/tests/conftest.py index 2746e53a..7aacfb46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -317,12 +317,12 @@ class InlineSnapshotExtension: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): - return builder.create_value(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) + return builder.create_code(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) @customize def fakedate_handler(self,value,builder): if isinstance(value,FakeDate): - return builder.create_value(value,value.__repr__().replace("FakeDate","datetime.date")) + return builder.create_code(value,value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 169ab440..8ceb291e 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -19,33 +19,41 @@ def test_enum(check_update): - assert ( - check_update( - """ + Example( + { + "color.py": """ from enum import Enum class color(Enum): val="val" +def get_color(): + return [color.val,color.val] -assert [color.val] == snapshot() - - """, - flags="create", - ) - == snapshot( - """\ - -from enum import Enum - -class color(Enum): - val="val" + """, + "test_color.py": """\ +from inline_snapshot import snapshot +from color import get_color +def test_enum(): + assert get_color() == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_color.py": """\ +from inline_snapshot import snapshot +from color import get_color -assert [color.val] == snapshot([color.val]) +from color import color +def test_enum(): + assert get_color() == snapshot([color.val, color.val]) """ - ) + } + ), ) diff --git a/tests/test_docs.py b/tests/test_docs.py index 518686f2..1edcc124 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -319,12 +319,12 @@ class InlineSnapshotExtension: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): - return builder.create_value(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) + return builder.create_code(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) @customize def fakedate_handler(self,value,builder): if isinstance(value,FakeDate): - return builder.create_value(value,value.__repr__().replace("FakeDate","datetime.date")) + return builder.create_code(value,value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) From caaa7b6a42e38889a0347308d5e1d897904cd20d Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 10 Jan 2026 21:56:36 +0100 Subject: [PATCH 37/72] docs: moved some docs --- docs/customize_repr.md | 2 +- docs/{customize.md => plugin.md} | 100 ++++++++++--- mkdocs.yml | 2 +- pyproject.toml | 4 +- src/inline_snapshot/__init__.py | 5 +- src/inline_snapshot/_customize/_builder.py | 13 +- src/inline_snapshot/_customize/_custom.py | 5 + .../_customize/_custom_external.py | 3 +- .../_customize/_custom_value.py | 20 ++- src/inline_snapshot/_global_state.py | 6 +- src/inline_snapshot/_snapshot/__init__.py | 0 src/inline_snapshot/plugin/__init__.py | 5 + src/inline_snapshot/plugin/_context_value.py | 8 - .../plugin/_context_variable.py | 17 +++ src/inline_snapshot/plugin/_default_plugin.py | 4 +- src/inline_snapshot/plugin/_spec.py | 141 ++++++++++-------- tests/adapter/test_dataclass.py | 2 +- tests/conftest.py | 2 +- tests/test_customize.py | 2 +- tests/test_docs.py | 5 +- 20 files changed, 234 insertions(+), 112 deletions(-) rename docs/{customize.md => plugin.md} (80%) create mode 100644 src/inline_snapshot/_snapshot/__init__.py create mode 100644 src/inline_snapshot/plugin/__init__.py delete mode 100644 src/inline_snapshot/plugin/_context_value.py create mode 100644 src/inline_snapshot/plugin/_context_variable.py diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 6eec2fc1..5ff5e621 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -4,7 +4,7 @@ You should use ``` python title="conftest.py" - class InlineSnapshotExtension: + class InlineSnapshotPlugin: @customize def my_class_handler(value, builder): if isinstance(value, MyClass): diff --git a/docs/customize.md b/docs/plugin.md similarity index 80% rename from docs/customize.md rename to docs/plugin.md index 1a0d71b5..716c17b6 100644 --- a/docs/customize.md +++ b/docs/plugin.md @@ -1,14 +1,58 @@ -`@customize` allows you to register special hooks to control how inline-snapshot generates your snapshots. -You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. -`@customize` hooks can have the following arguments (but you do not have to use all of them). -* **value:** the value of your snapshot that is currently being converted to source code. -* **builder:** your `Builder` object can be used to create Custom objects that represent your new code. -* **local_vars:** a list of objects with `name` and `value` attributes that represent the local variables that are usable in your snapshot. -* **global_vars:** same as `local_vars`, but for global variables. +inline-snapshot provides a plugin architecture based on [pluggy](https://pluggy.readthedocs.io/en/latest/index.html) which can be used to extend and customize it. + +The plugins are searched in your `conftest.py` and has to be called `InlineSnapshotPlugin`. + +You can also create packages which provide plugins using setuptools entry points. + +### Creating a Plugin Package + +To distribute inline-snapshot plugins as a package, register your plugin class using the `inline_snapshot` entry point in your `setup.py` or `pyproject.toml`: + +=== "pyproject.toml (recommended)" + ``` toml + [project.entry-points.inline-snapshot] + my_plugin = "my_package.plugin:MyInlineSnapshotPlugin" + ``` + +=== "setup.py" + ``` python + setup( + name="my-inline-snapshot-plugin", + entry_points={ + "inline-snapshot": [ + "my_plugin = my_package.plugin:MyInlineSnapshotPlugin", + ], + }, + ) + ``` + +``` python title="my_package/plugin.py" +class MyInlineSnapshotPlugin: ... +``` + +Once installed, the plugin will be automatically loaded by inline-snapshot. + +### Plugin Specification + +::: inline_snapshot.plugin + options: + heading_level: 3 + members: [InlineSnapshotPluginSpec] + show_root_heading: false + show_bases: false + show_source: false + + + +## Customize Examples + +The [customize][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize] hook controls how inline-snapshot generates your snapshots. +You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. -## Custom constructor methods + +### Custom constructor methods One use case might be that you have a dataclass with a special constructor function that can be used for specific instances of this dataclass, and you want inline-snapshot to use this constructor when possible. @@ -27,8 +71,6 @@ class Rect: ``` You can define a hook in your `conftest.py` that checks if your value is a square and calls the correct constructor function. -Inline-snapshot tries each hook until it finds one that does not return None. -It keeps converting this value until a hook returns a Custom object, which can be created with the `create_*` methods of the [`Builder`][inline_snapshot.Builder]. ``` python title="conftest.py" @@ -37,7 +79,7 @@ from inline_snapshot import customize from inline_snapshot import Builder -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def square_handler(self, value, builder: Builder): if isinstance(value, Rect) and value.width == value.height: @@ -62,7 +104,7 @@ def test_square(): 2. Your handler is used because you created a Rect that happens to have the same width and height 3. Your handler is not used because width and height are different -## dirty-equal expressions +### dirty-equal expressions It can also be used to instruct inline-snapshot to use specific dirty-equals expressions for specific values. @@ -72,7 +114,7 @@ from inline_snapshot import Builder from dirty_equals import IsNow -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def is_now_handler(self, value): if value == IsNow(): @@ -101,7 +143,7 @@ def test_is_now(): Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it when you change the `@customize` implementation, because it has to assume that it was created by the user. -## Conditional external objects +### Conditional external objects `create_external` can be used to store values in external files if a specific criterion is met. @@ -112,7 +154,7 @@ from inline_snapshot import Builder from dirty_equals import IsNow -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def long_string_handler(self, value, builder: Builder): if isinstance(value, str) and value.count("\n") > 5: @@ -140,7 +182,7 @@ c\ ) ``` -## Reusing local variables +### Reusing local variables There are times when your local or global variables become part of your snapshots, like uuids or user names. Customize hooks accept `local_vars` and `global_vars` as arguments that can be used to generate the code. @@ -151,7 +193,7 @@ from inline_snapshot import customize from inline_snapshot import Builder -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def local_var_handler(self, value, local_vars): for local in local_vars: @@ -187,7 +229,7 @@ It is up to you to set the rules that work best in your project. This is also the reason why inline-snapshot does not provide default customizations that check your local variables. The rules are project-specific and what might work well for one project can cause problems for others. -## Creating special code +### Creating special code Let's say that you have an array of secrets which are used in your code. @@ -220,7 +262,7 @@ def test_my_class(): ``` Maybe this is not what you want because the secret is a different one in CI or for every test run or the raw value leads to unreadable tests. -What you can do now, instead of replacing `"some_other_secret"` with `secrets[1]` by hand, is to tell inline-snapshot how it should generate this code in your *conftest.py*. +What you can do now, instead of replacing `"some_other_secret"` with `secrets[1]` by hand, is to tell inline-snapshot in your *conftest.py* how it should generate this code. ``` python title="conftest.py" @@ -228,7 +270,7 @@ from my_secrets import secrets from inline_snapshot import customize, Builder -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def secret_handler(self, value, builder: Builder): for i, secret in enumerate(secrets): @@ -263,12 +305,26 @@ def test_my_class(): You can read [here](categories.md#update) more about it. -# Reference + + + + + + + +## Reference +::: inline_snapshot.plugin + options: + heading_level: 3 + members: [hookimpl,customize] + show_root_heading: false + show_bases: false + show_source: false ::: inline_snapshot options: heading_level: 3 - members: [customize,Builder,Custom] + members: [Builder,Custom,CustomCode,ContextVariable] show_root_heading: false show_bases: false show_source: false diff --git a/mkdocs.yml b/mkdocs.yml index cf13ae26..fb37b78f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -47,6 +47,7 @@ nav: - Pytest: pytest.md - PyCharm: pycharm.md - Categories: categories.md + - Plugin: plugin.md - Code generation: code_generation.md - Limitations: limitations.md - Alternatives: alternatives.md @@ -63,7 +64,6 @@ nav: - external_file(): external/external_file.md - outsource(): external/outsource.md - '@register_format()': external/register_format.md - - '@customize': customize.md - '@customize_repr': customize_repr.md - types: types.md - get_snapshot_value(): get_snapshot_value.md diff --git a/pyproject.toml b/pyproject.toml index 81f9b7e4..b89e05d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,9 +127,9 @@ dependencies = [ ] [tool.hatch.envs.docs.scripts] -build = "mkdocs build --strict" +build = "mkdocs build --strict {args}" export-deps = "pip freeze" -serve = "mkdocs serve --livereload" +serve = "mkdocs serve --livereload {args}" [tool.hatch.envs.default] diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 48bb7569..b10afcdf 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -1,5 +1,6 @@ from inline_snapshot._external._diff import BinaryDiff from inline_snapshot._external._diff import TextDiff +from inline_snapshot.plugin._context_variable import ContextVariable from ._code_repr import HasRepr from ._code_repr import customize_repr @@ -18,7 +19,7 @@ from ._types import Category from ._types import Snapshot from ._unmanaged import declare_unmanaged -from .plugin._spec import customize +from .plugin import customize from .version import __version__ __all__ = [ @@ -40,7 +41,7 @@ "declare_unmanaged", "get_snapshot_value", "__version__", - "customize", "Custom", "Builder", + "ContextVariable", ] diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 38fcc903..d94bd751 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -5,7 +5,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._compare_context import compare_context -from inline_snapshot.plugin._context_value import ContextValue +from inline_snapshot.plugin._context_variable import ContextVariable from ._custom import Custom from ._custom_call import CustomCall @@ -31,12 +31,12 @@ def _get_handler(self, v) -> Custom: and (frame := self._snapshot_context.frame) is not None ): local_vars = [ - ContextValue(var_name, var_value) + ContextVariable(var_name, var_value) for var_name, var_value in frame.locals.items() if "@" not in var_name ] global_vars = [ - ContextValue(var_name, var_value) + ContextVariable(var_name, var_value) for var_name, var_value in frame.globals.items() if "@" not in var_name ] @@ -132,7 +132,10 @@ def create_code(self, value, repr: str | None = None) -> CustomCode: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. - `create_code(my_obj, 'MyClass')` becomes `MyClass` in the code. - Use this when you want to control the exact string representation of a value. + `create_code(value, '{value-1!r}+1')` becomes `4+1` in the code for a given `value=5`. + Use this when you need to control the exact string representation of a value. + + You can use `.with_import(module,name)` to create an import in the code. + `create_code(Color.red,"Color.red").with_import("my_colors","Color")` will create a `from my_colors import Color` if needed and `Color.red` in the code. """ return CustomCode(value, repr) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index eabc34e6..08243dd7 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -15,6 +15,11 @@ class Custom(ABC): + """ + Custom objects are returned by the `create_*` functions of the builder. + They should only be returned in your customize function or used as arguments for other `create_*` functions. + """ + node_type: type[ast.AST] = ast.AST original_value: Any diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index 0ef47092..91a619e9 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -10,7 +10,6 @@ from inline_snapshot._change import ExternalChange from inline_snapshot._external._external_location import ExternalLocation from inline_snapshot._external._format._protocol import get_format_handler -from inline_snapshot._global_state import state from ._custom import Custom @@ -25,6 +24,8 @@ def map(self, f): return f(self.value) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + from inline_snapshot._global_state import state + storage_name = self.storage or state().config.default_storage format = get_format_handler(self.value, self.format or "") diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index a6bf381e..bc946374 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -54,12 +54,30 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: return self.repr_str def __repr__(self): - return f"CustomCode({self.repr_str})" + return f"CustomValue({self.repr_str})" def _needed_imports(self): yield from self._imports def with_import(self, module, name, simplify=True): + """ + Adds a `from module import name` statement to the generated code. + + Arguments: + module: The module path to import from (e.g., "my_module" or "package.submodule"). + name: The name to import from the module (e.g., "MyClass" or "my_function"). + simplify: If True (default), attempts to find the shortest valid import path + by checking parent modules. For example, if "package.submodule.MyClass" + is accessible from "package", it will use the shorter path. + + Returns: + The CustomCode instance itself, allowing for method chaining. + + Example: + ``` python + builder.create_value(my_obj, "secrets[0]").with_import("my_secrets", "secrets") + ``` + """ if simplify: module = _simplify_module_path(module, name) self._imports.append([module, name]) diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 3338feb2..fe2c67ef 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -15,6 +15,7 @@ import pluggy from inline_snapshot._config import Config +from inline_snapshot.plugin._spec import inline_snapshot_plugin_name if TYPE_CHECKING: from inline_snapshot._external._format._protocol import Format @@ -53,7 +54,7 @@ class State: ) pm: pluggy.PluginManager = field( - default_factory=lambda: pluggy.PluginManager("inline_snapshot") + default_factory=lambda: pluggy.PluginManager(inline_snapshot_plugin_name) ) def new_tmp_path(self, suffix: str) -> Path: @@ -84,7 +85,6 @@ def enter_snapshot_context(): from .plugin._spec import InlineSnapshotPluginSpec _current.pm.add_hookspecs(InlineSnapshotPluginSpec) - _current.pm.load_setuptools_entrypoints("inline_snapshot") from .plugin._default_plugin import InlineSnapshotPlugin @@ -104,6 +104,8 @@ def enter_snapshot_context(): else: _current.pm.register(InlineSnapshotPydanticPlugin()) + _current.pm.load_setuptools_entrypoints(inline_snapshot_plugin_name) + def leave_snapshot_context(): global _current diff --git a/src/inline_snapshot/_snapshot/__init__.py b/src/inline_snapshot/_snapshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py new file mode 100644 index 00000000..4958be0e --- /dev/null +++ b/src/inline_snapshot/plugin/__init__.py @@ -0,0 +1,5 @@ +from ._spec import InlineSnapshotPluginSpec +from ._spec import customize +from ._spec import hookimpl + +__all__ = ("InlineSnapshotPluginSpec", "customize", "hookimpl") diff --git a/src/inline_snapshot/plugin/_context_value.py b/src/inline_snapshot/plugin/_context_value.py deleted file mode 100644 index 2d2a05b0..00000000 --- a/src/inline_snapshot/plugin/_context_value.py +++ /dev/null @@ -1,8 +0,0 @@ -from dataclasses import dataclass -from typing import Any - - -@dataclass -class ContextValue: - name: str - value: Any diff --git a/src/inline_snapshot/plugin/_context_variable.py b/src/inline_snapshot/plugin/_context_variable.py new file mode 100644 index 00000000..2a148224 --- /dev/null +++ b/src/inline_snapshot/plugin/_context_variable.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ContextVariable: + """ + representation of a value in the local or global context of a snapshot. + + This type can also be returned in an customize function and is then converted into an [Custom][inline_snapshot.Custom] object by an other Function + """ + + name: str + "the name of the variable" + + value: Any + "the value of the variable" diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 72ffb8b1..f22a0339 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -19,7 +19,7 @@ from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import triple_quote -from inline_snapshot.plugin._context_value import ContextValue +from inline_snapshot.plugin._context_variable import ContextVariable from ._spec import customize @@ -233,7 +233,7 @@ def dirty_equals_handler(self, value, builder: Builder): @customize(tryfirst=True) def context_value_handler(self, value, builder: Builder): - if isinstance(value, ContextValue): + if isinstance(value, ContextVariable): return builder.create_code(value.value, value.name) @customize diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index 553394a3..c2c9fbce 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -5,84 +5,107 @@ import pluggy from inline_snapshot._customize._builder import Builder -from inline_snapshot.plugin._context_value import ContextValue +from inline_snapshot.plugin._context_variable import ContextVariable -hookspec = pluggy.HookspecMarker("inline_snapshot") -hookimpl = pluggy.HookimplMarker("inline_snapshot") +inline_snapshot_plugin_name = "inline-snapshot" + +hookspec = pluggy.HookspecMarker(inline_snapshot_plugin_name) +""" +The pluggy hookspec marker for inline_snapshot. +""" + +hookimpl = pluggy.HookimplMarker(inline_snapshot_plugin_name) +""" +The pluggy hookimpl marker for inline_snapshot. +""" customize = partial(hookimpl, specname="customize") """ - Registers a function as a customization hook inside inline-snapshot. +Decorator to mark a function as an implementation of the `customize` hook which can be used instead of `hookimpl(specname="customize")`. +""" - Customization hooks allow you to control how objects are represented in snapshot code. - When inline-snapshot generates code for a value, it calls each registered customization - function in reverse order of registration until one returns a Custom object. - **Important**: Customization handlers should be registered in your `conftest.py` file to ensure - they are loaded before your tests run. +class InlineSnapshotPluginSpec: + @hookspec(firstresult=True) + def customize( + self, + value: Any, + builder: Builder, + local_vars: List[ContextVariable], + global_vars: List[ContextVariable], + ) -> Any: + """ + The customize hook is called every time a snapshot value should be converted into code. + This hook allows you to control how inline-snapshot represents objects in generated code. + When multiple handlers are registered, they are called until one returns a non-None value. + `customize` is also called for each attribute of the converted hook which is not a Custom node, which means that a hook for `int` does not only apply for `snapshot(5)` but also for `snaspshot([1,2,3])` 3 times. - Example: - Basic usage with a custom class: - - ``` python title="myclass.py" - class MyClass: - def __init__(self, arg1, arg2, key=None): - self.arg1 = arg1 - self.arg2 = arg2 - self.key_attr = key - ``` + Arguments: + value: The Python object that needs to be converted into source code representation. + This is the actual runtime value from your test. + builder: A Builder instance providing methods to construct custom code representations. + Use methods like `create_call()`, `create_dict()`, `create_external()`, etc. + local_vars: List of local variables available in the current scope, each containing + `name` and `value` attributes. Useful for referencing existing variables + instead of creating new literals. + global_vars: List of global variables available in the current scope, each containing + `name` and `value` attributes. - - ``` python title="conftest.py" - from myclass import MyClass - from inline_snapshot import customize + Returns: + (Custom): created using [Builder][inline_snapshot.Builder] `create_*` methods. + (None): if this handler doesn't apply to the given value. + (Something else): when the next handler should process the value. + Example: + You can use @customize when you want to specify multiple handlers in the same class: - @customize - def my_custom_handler(value, builder): - if isinstance(value, MyClass): - # Generate code like: MyClass(arg1, arg2, key=value) - return builder.create_call( - MyClass, [value.arg1, value.arg2], {"key": value.key_attr} - ) - return None # Let other handlers process this value - ``` - - ``` python - from inline_snapshot import snapshot - from myclass import MyClass + === "with @customize" + + ``` python title="conftest.py" + from inline_snapshot.plugin import customize - def test_myclass(): - obj = MyClass(42, "hello", key="world") - assert obj == snapshot(MyClass(42, "hello", key="world")) - ``` + class InlineSnapshotPlugin: + @customize + def binary_numbers(self, value, builder, local_vars, global_vars): + if isinstance(value, int): + return builder.create_code(value, bin(value)) - Note: - - **Always register handlers in `conftest.py`** to ensure they're available for all tests - - If no handler returns a Custom object, a default representation is used - - Use builder methods (`create_call`, `create_external`) to construct representations - - Always return `None` if your handler doesn't apply to the given value type - - The builder automatically handles recursive conversion of nested values, therfor `create_list` and `create_dict` are unlikely needed because you can just use `[]` or `{}` + @customize + def repeated_strings(self, value, builder): + if isinstance(value, str) and value == value[0] * len(value): + return builder.create_code(value, f"'{value[0]}'*{len(value)}") + ``` + === "by method name" - See Also: - - [Builder][inline_snapshot._customize.Builder]: Available builder methods - - [Custom][inline_snapshot._customize.Custom]: Base class for custom representations - """ + + ``` python title="conftest.py" + class InlineSnapshotPlugin: + def customize(self, value, builder, local_vars, global_vars): + if isinstance(value, int): + return builder.create_code(value, bin(value)) + if isinstance(value, str) and value == value[0] * len(value): + return builder.create_code(value, f"'{value[0]}'*{len(value)}") + ``` -class InlineSnapshotPluginSpec: - @hookspec(firstresult=True) - def customize( - self, - value: Any, - builder: Builder, - local_vars: List[ContextValue], - global_vars: List[ContextValue], - ) -> Any: ... + + + ``` python title="test_customizations.py" + from inline_snapshot import snapshot + + + def test_customizations(): + assert ["aaaaaaaaaaaaaaa", "bbbbbb"] == snapshot(["a" * 15, "b" * 6]) + assert 18856 == snapshot(0b100100110101000) + ``` + + + + """ # @hookspec # def format_code(self, filename, str) -> str: ... diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 6aa073c8..414f286f 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -502,7 +502,7 @@ def __eq__(self,other): from inline_snapshot import customize from helper import L -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def handle_L(self,value,builder): if isinstance(value,L): diff --git a/tests/conftest.py b/tests/conftest.py index 7aacfb46..22aea413 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -313,7 +313,7 @@ def setup(self, source: str, add_header=True): from freezegun.api import FakeDatetime,FakeDate from inline_snapshot import customize -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): diff --git a/tests/test_customize.py b/tests/test_customize.py index 4d96de99..29846dc1 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -16,7 +16,7 @@ def test_custom_dirty_equal(original, flag): from inline_snapshot import Builder from dirty_equals import IsStr -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def re_handler(self,value, builder: Builder): if value == IsStr(regex="[a-z]"): diff --git a/tests/test_docs.py b/tests/test_docs.py index 1edcc124..a3da8e91 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -48,7 +48,7 @@ def map_code_blocks(file: Path, func): code = None indent = "" block_start_linenum: Optional[int] = None - block_options: Optional[str] = None + block_options: dict[str, str] = {} code_header = None header_line = "" block_found = False @@ -68,7 +68,6 @@ def map_code_blocks(file: Path, func): if block_end.fullmatch(line.strip()) and is_block: # ``` is_block = False - assert block_options is not None assert block_start_linenum is not None code = "\n".join(block_lines) + "\n" @@ -315,7 +314,7 @@ def file_test( from freezegun.api import FakeDatetime,FakeDate from inline_snapshot import customize -class InlineSnapshotExtension: +class InlineSnapshotPlugin: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): From 299e9091421d073cc63b7dc4f53dc4b31595ffff Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 09:31:30 +0100 Subject: [PATCH 38/72] docs: moved some docs --- docs/plugin.md | 18 +++++++++--------- src/inline_snapshot/__init__.py | 7 ------- src/inline_snapshot/_snapshot/dict_value.py | 2 +- src/inline_snapshot/plugin/__init__.py | 12 +++++++++++- tests/adapter/test_dataclass.py | 2 +- tests/conftest.py | 2 +- tests/test_customize.py | 4 ++-- tests/test_docs.py | 2 +- 8 files changed, 26 insertions(+), 23 deletions(-) diff --git a/docs/plugin.md b/docs/plugin.md index 716c17b6..8a1c9722 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -75,8 +75,8 @@ You can define a hook in your `conftest.py` that checks if your value is a squar ``` python title="conftest.py" from rect import Rect -from inline_snapshot import customize -from inline_snapshot import Builder +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder class InlineSnapshotPlugin: @@ -109,8 +109,8 @@ It can also be used to instruct inline-snapshot to use specific dirty-equals exp ``` python title="conftest.py" -from inline_snapshot import customize -from inline_snapshot import Builder +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder from dirty_equals import IsNow @@ -149,8 +149,8 @@ def test_is_now(): ``` python title="conftest.py" -from inline_snapshot import customize -from inline_snapshot import Builder +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder from dirty_equals import IsNow @@ -189,8 +189,8 @@ Customize hooks accept `local_vars` and `global_vars` as arguments that can be u ``` python title="conftest.py" -from inline_snapshot import customize -from inline_snapshot import Builder +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder class InlineSnapshotPlugin: @@ -267,7 +267,7 @@ What you can do now, instead of replacing `"some_other_secret"` with `secrets[1] ``` python title="conftest.py" from my_secrets import secrets -from inline_snapshot import customize, Builder +from inline_snapshot.plugin import customize, Builder class InlineSnapshotPlugin: diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index b10afcdf..69a573b2 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -1,11 +1,8 @@ from inline_snapshot._external._diff import BinaryDiff from inline_snapshot._external._diff import TextDiff -from inline_snapshot.plugin._context_variable import ContextVariable from ._code_repr import HasRepr from ._code_repr import customize_repr -from ._customize._builder import Builder -from ._customize._custom import Custom from ._exceptions import UsageError from ._external._external import external from ._external._external_file import external_file @@ -19,7 +16,6 @@ from ._types import Category from ._types import Snapshot from ._unmanaged import declare_unmanaged -from .plugin import customize from .version import __version__ __all__ = [ @@ -41,7 +37,4 @@ "declare_unmanaged", "get_snapshot_value", "__version__", - "Custom", - "Builder", - "ContextVariable", ] diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index efec48f5..bb92dd8d 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -10,8 +10,8 @@ from .._change import Delete from .._change import DictInsert from .._global_state import state -from .._inline_snapshot import UndecidedValue from .generic_value import GenericValue +from .undecided_value import UndecidedValue class DictValue(GenericValue): diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py index 4958be0e..a397cdef 100644 --- a/src/inline_snapshot/plugin/__init__.py +++ b/src/inline_snapshot/plugin/__init__.py @@ -1,5 +1,15 @@ +from .._customize._builder import Builder +from .._customize._custom import Custom +from ._context_variable import ContextVariable from ._spec import InlineSnapshotPluginSpec from ._spec import customize from ._spec import hookimpl -__all__ = ("InlineSnapshotPluginSpec", "customize", "hookimpl") +__all__ = ( + "InlineSnapshotPluginSpec", + "customize", + "hookimpl", + "Builder", + "ContextVariable", + "Custom", +) diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 414f286f..cfb81682 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -499,7 +499,7 @@ def __eq__(self,other): return other.l==self.l """, "tests/conftest.py": """\ -from inline_snapshot import customize +from inline_snapshot.plugin import customize from helper import L class InlineSnapshotPlugin: diff --git a/tests/conftest.py b/tests/conftest.py index 22aea413..c3a70545 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -311,7 +311,7 @@ def setup(self, source: str, add_header=True): import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize +from inline_snapshot.plugin import customize class InlineSnapshotPlugin: @customize diff --git a/tests/test_customize.py b/tests/test_customize.py index 29846dc1..00e6b2cd 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -12,8 +12,8 @@ def test_custom_dirty_equal(original, flag): Example( { "tests/conftest.py": """\ -from inline_snapshot import customize -from inline_snapshot import Builder +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder from dirty_equals import IsStr class InlineSnapshotPlugin: diff --git a/tests/test_docs.py b/tests/test_docs.py index a3da8e91..27c453ca 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -312,7 +312,7 @@ def file_test( import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize +from inline_snapshot.plugin import customize class InlineSnapshotPlugin: @customize From 2ba99005b3257e533c89359916bc2c65c788b925 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 15:59:08 +0100 Subject: [PATCH 39/72] docs: improved docs --- docs/plugin.md | 80 ++++++++++++++----- src/inline_snapshot/plugin/__init__.py | 3 + .../plugin/_context_variable.py | 10 ++- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/docs/plugin.md b/docs/plugin.md index 8a1c9722..bd624a1f 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -2,13 +2,45 @@ inline-snapshot provides a plugin architecture based on [pluggy](https://pluggy.readthedocs.io/en/latest/index.html) which can be used to extend and customize it. -The plugins are searched in your `conftest.py` and has to be called `InlineSnapshotPlugin`. +## Overview -You can also create packages which provide plugins using setuptools entry points. +Plugins allow you to customize how inline-snapshot generates code for your snapshots. The primary use case is implementing custom representation logic through the `@customize` hook, which controls how Python objects are converted into source code. + +### When to Use Plugins + +You should consider creating a plugin when: + +- You find yourself manually editing snapshots after they are generated +- You want to use custom constructors or factory methods in your snapshots +- You need to reference local/global variables instead of hardcoding values +- You want to store certain values in external files based on specific criteria +- You need special code representations for your custom types + +### Plugin Capabilities + +Plugins can: + +- **Customize code generation**: Control how objects appear in snapshot code (e.g., use `Color.RED` instead of `Color(255, 0, 0)`) +- **Reference variables**: Use existing local or global variables in snapshots instead of literals +- **External storage**: Automatically store large or sensitive values in external files +- **Import management**: Automatically add necessary import statements to test files + +## Plugin Discovery + +inline-snapshot loads the plugins at the beginning of the session. +It searches for plugins in: +* installed packages with an `inline-snapshot` entry point +* your pytest `conftest.py` files + +### Loading Plugins from conftest.py + +Loading plugins from the `conftest.py` files is the recommended way when you want to change the behavior of inline-snapshot in your own project. + +The plugins are searched in your `conftest.py` and the name has to start with `InlineSnapshot*`. Each plugin which is loaded from your `conftest.py` is active globally for all your tests. ### Creating a Plugin Package -To distribute inline-snapshot plugins as a package, register your plugin class using the `inline_snapshot` entry point in your `setup.py` or `pyproject.toml`: +To distribute inline-snapshot plugins as a package, register your plugin class using the `inline-snapshot` entry point in your `setup.py` or `pyproject.toml`: === "pyproject.toml (recommended)" ``` toml @@ -28,8 +60,23 @@ To distribute inline-snapshot plugins as a package, register your plugin class u ) ``` +Your plugin class should contain methods decorated with `@customize`, just like in conftest.py: + ``` python title="my_package/plugin.py" -class MyInlineSnapshotPlugin: ... +from inline_snapshot.plugin import customize, Builder + + +class MyInlineSnapshotPlugin: + """ + This class will be instantiated by inline-snapshot when the package is installed. + Typically used by library authors who want to provide inline-snapshot integration. + """ + + @customize + def my_custom_handler(self, value, builder: Builder): + # Your customization logic here + if isinstance(value, YourCustomType): + return builder.create_call(YourCustomType, [value.arg]) ``` Once installed, the plugin will be automatically loaded by inline-snapshot. @@ -48,6 +95,8 @@ Once installed, the plugin will be automatically loaded by inline-snapshot. ## Customize Examples +The following examples demonstrate common use cases for the `@customize` hook. Each example shows how to implement custom representation logic for different scenarios. + The [customize][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize] hook controls how inline-snapshot generates your snapshots. You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. @@ -98,11 +147,15 @@ def test_square(): assert Rect.make_square(5) == snapshot(Rect.make_square(5)) # (1)! assert Rect(1, 1) == snapshot(Rect.make_square(1)) # (2)! assert Rect(1, 2) == snapshot(Rect(width=1, height=2)) # (3)! + assert [Rect(3, 3), Rect(4, 5)] == snapshot( + [Rect.make_square(3), Rect(width=4, height=5)] + ) # (4)! ``` 1. Your handler is used because you created a square 2. Your handler is used because you created a Rect that happens to have the same width and height 3. Your handler is not used because width and height are different +4. The handler is applied recursively to each Rect inside the list - the first is converted to `make_square()` while the second uses the regular constructor ### dirty-equal expressions It can also be used to instruct inline-snapshot to use specific dirty-equals expressions for specific values. @@ -121,8 +174,7 @@ class InlineSnapshotPlugin: return IsNow ``` -Inline-snapshot provides a handler that can convert dirty-equals expressions back into source code. This allows you to return `IsNow` here without the need to construct a custom object with the builder. -This works because the value is converted with the customize functions until one hook uses the builder to create a Custom object. +As explained in the [customize hook specification][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize], you can return types other than Custom objects. Inline-snapshot includes a built-in handler in its default plugin that converts dirty-equals expressions back into source code, which is why you can return `IsNow` directly without using the builder. This approach is much simpler than using `builder.create_call()` for complex dirty-equals expressions. ``` python title="test_is_now.py" @@ -139,7 +191,7 @@ def test_is_now(): 1. Inline-snapshot also creates the imports when they are missing !!! important - Inline-snapshot will never change the dirty-equals expressions in your code because they are unmanaged. + Inline-snapshot will never change the dirty-equals expressions in your code because they are [unmanaged](eq_snapshot.md#unmanaged-snapshot-values). Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it when you change the `@customize` implementation, because it has to assume that it was created by the user. @@ -151,7 +203,6 @@ def test_is_now(): ``` python title="conftest.py" from inline_snapshot.plugin import customize from inline_snapshot.plugin import Builder -from dirty_equals import IsNow class InlineSnapshotPlugin: @@ -164,7 +215,6 @@ class InlineSnapshotPlugin: ``` python title="test_long_strings.py" from inline_snapshot import snapshot -from datetime import datetime from inline_snapshot import external @@ -280,6 +330,8 @@ class InlineSnapshotPlugin: ) ``` +The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the actual value and its desired code representation, then [`with_import()`][inline_snapshot.plugin.CustomCode.with_import] adds the necessary import statement. + Inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. @@ -316,15 +368,7 @@ def test_my_class(): ::: inline_snapshot.plugin options: heading_level: 3 - members: [hookimpl,customize] - show_root_heading: false - show_bases: false - show_source: false - -::: inline_snapshot - options: - heading_level: 3 - members: [Builder,Custom,CustomCode,ContextVariable] + members: [hookimpl,customize,Builder,Custom,CustomCode,ContextVariable] show_root_heading: false show_bases: false show_source: false diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py index a397cdef..8e6c2b06 100644 --- a/src/inline_snapshot/plugin/__init__.py +++ b/src/inline_snapshot/plugin/__init__.py @@ -1,3 +1,5 @@ +from inline_snapshot._customize._custom_value import CustomCode + from .._customize._builder import Builder from .._customize._custom import Custom from ._context_variable import ContextVariable @@ -12,4 +14,5 @@ "Builder", "ContextVariable", "Custom", + "CustomCode", ) diff --git a/src/inline_snapshot/plugin/_context_variable.py b/src/inline_snapshot/plugin/_context_variable.py index 2a148224..3f57ede1 100644 --- a/src/inline_snapshot/plugin/_context_variable.py +++ b/src/inline_snapshot/plugin/_context_variable.py @@ -5,9 +5,15 @@ @dataclass class ContextVariable: """ - representation of a value in the local or global context of a snapshot. + Representation of a value in the local or global context of a snapshot. - This type can also be returned in an customize function and is then converted into an [Custom][inline_snapshot.Custom] object by an other Function + This type can be returned from a customize function to reference an existing variable + instead of creating a new literal value. Inline-snapshot includes a built-in handler + that converts ContextVariable instances into [Custom][inline_snapshot.plugin.Custom] + objects, generating code that references the variable by name. + + ContextVariable instances are provided via the `local_vars` and `global_vars` parameters + of the [customize hook][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize]. """ name: str From 3852bc8550a42f990a17ab5f49db4c70a4790c0a Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 16:56:08 +0100 Subject: [PATCH 40/72] docs: sort imports --- README.md | 4 ++-- docs/code_generation.md | 4 ++-- docs/customize_repr.md | 4 ++-- docs/eq_snapshot.md | 28 ++++++++++++++-------------- docs/external/external.md | 8 ++++---- docs/external/register_format.md | 2 +- docs/plugin.md | 13 +++++-------- docs/testing.md | 8 ++++---- pyproject.toml | 20 +++----------------- src/inline_snapshot/_types.py | 2 +- src/inline_snapshot/extra.py | 22 +++++++++++----------- tests/test_docs.py | 10 ++++++++++ 12 files changed, 59 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 54c09b98..5fdabde2 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ The following examples show how you can use inline-snapshot in your tests. Take ``` python -from inline_snapshot import snapshot, outsource, external +from inline_snapshot import external, outsource, snapshot def test_something(): @@ -135,9 +135,9 @@ strings\ ``` python -from inline_snapshot import snapshot import subprocess as sp import sys +from inline_snapshot import snapshot def run_python(cmd, stdout=None, stderr=None): diff --git a/docs/code_generation.md b/docs/code_generation.md index 3556780b..9a472640 100644 --- a/docs/code_generation.md +++ b/docs/code_generation.md @@ -7,8 +7,8 @@ It might be necessary to import the right modules to match the `repr()` output. === "original code" ``` python - from inline_snapshot import snapshot import datetime + from inline_snapshot import snapshot def something(): @@ -29,8 +29,8 @@ It might be necessary to import the right modules to match the `repr()` output. === "--inline-snapshot=create" ``` python hl_lines="18 19 20 21 22 23 24 25 26 27 28" - from inline_snapshot import snapshot import datetime + from inline_snapshot import snapshot def something(): diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 5ff5e621..99c799b0 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -73,8 +73,8 @@ This implementation is then used by inline-snapshot if `repr()` is called during ``` python -from inline_snapshot import snapshot from my_class import MyClass +from inline_snapshot import snapshot def test_my_class(): @@ -97,8 +97,8 @@ You can also use `repr()` inside `__repr__()`, if you want to make your own type ``` python -from inline_snapshot import snapshot from enum import Enum +from inline_snapshot import snapshot class Pair: diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index d5af2cb2..41340811 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -117,8 +117,8 @@ You could initialize a test like this: ``` python -from inline_snapshot import snapshot import datetime +from inline_snapshot import snapshot def get_data(): @@ -136,8 +136,8 @@ If you use `--inline-snapshot=create`, inline-snapshot will record the current ` ``` python hl_lines="13 14 15" -from inline_snapshot import snapshot import datetime +from inline_snapshot import snapshot def get_data(): @@ -157,9 +157,9 @@ To avoid the test failing in future runs, replace the `datetime` with [dirty-equ ``` python -from inline_snapshot import snapshot -from dirty_equals import IsDatetime import datetime +from dirty_equals import IsDatetime +from inline_snapshot import snapshot def get_data(): @@ -182,9 +182,9 @@ Say a different part of the return data changes, such as the `payload` value: ``` python hl_lines="9" -from inline_snapshot import snapshot -from dirty_equals import IsDatetime import datetime +from dirty_equals import IsDatetime +from inline_snapshot import snapshot def get_data(): @@ -207,9 +207,9 @@ Re-running the test with `--inline-snapshot=fix` will update the snapshot to mat ``` python hl_lines="17" -from inline_snapshot import snapshot -from dirty_equals import IsDatetime import datetime +from dirty_equals import IsDatetime +from inline_snapshot import snapshot def get_data(): @@ -270,7 +270,7 @@ It tells inline-snapshot that the developer wants control over some part of the ``` python -from inline_snapshot import snapshot, Is +from inline_snapshot import Is, snapshot current_version = "1.5" @@ -292,7 +292,7 @@ The snapshot does not need to be fixed when `current_version` changes in the fut === "original code" ``` python - from inline_snapshot import snapshot, Is + from inline_snapshot import Is, snapshot def test_function(): @@ -303,7 +303,7 @@ The snapshot does not need to be fixed when `current_version` changes in the fut === "--inline-snapshot=fix" ``` python hl_lines="6" - from inline_snapshot import snapshot, Is + from inline_snapshot import Is, snapshot def test_function(): @@ -348,8 +348,8 @@ The following example shows how this can be used to run a tests with two differe ``` python +from my_lib import get_schema, version from inline_snapshot import snapshot -from my_lib import version, get_schema def test_function(): @@ -368,8 +368,8 @@ The advantage of this approach is that the test uses always the correct values f You can also extract the version logic into its own function. ``` python -from inline_snapshot import snapshot, Snapshot -from my_lib import version, get_schema +from my_lib import get_schema, version +from inline_snapshot import Snapshot, snapshot def version_snapshot(v1: Snapshot, v2: Snapshot): diff --git a/docs/external/external.md b/docs/external/external.md index bfecbe45..b7ac463f 100644 --- a/docs/external/external.md +++ b/docs/external/external.md @@ -64,7 +64,7 @@ The `external()` function can also be used inside other data structures. ``` python -from inline_snapshot import snapshot, external +from inline_snapshot import external, snapshot def test_something(): @@ -75,7 +75,7 @@ def test_something(): ``` python hl_lines="6 7 8 9 10 11 12 13" -from inline_snapshot import snapshot, external +from inline_snapshot import external, snapshot def test_something(): @@ -176,7 +176,7 @@ You must specify the suffix in this case. ``` python -from inline_snapshot import register_format_alias, external +from inline_snapshot import external, register_format_alias register_format_alias(".html", ".txt") @@ -189,7 +189,7 @@ inline-snapshot uses the given suffix to create an external snapshot. ``` python hl_lines="7 8 9" -from inline_snapshot import register_format_alias, external +from inline_snapshot import external, register_format_alias register_format_alias(".html", ".txt") diff --git a/docs/external/register_format.md b/docs/external/register_format.md index 216ec030..c0e72cbb 100644 --- a/docs/external/register_format.md +++ b/docs/external/register_format.md @@ -129,8 +129,8 @@ The custom format is then used every time a `NumberSet` is compared with an empt === "example" ``` python - from inline_snapshot import external from number_set import NumberSet + from inline_snapshot import external def test(): diff --git a/docs/plugin.md b/docs/plugin.md index bd624a1f..8471e5d4 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -139,8 +139,8 @@ This allows you to influence the code that is created by inline-snapshot. ``` python title="test_square.py" -from inline_snapshot import snapshot from rect import Rect +from inline_snapshot import snapshot def test_square(): @@ -178,10 +178,9 @@ As explained in the [customize hook specification][inline_snapshot.plugin.Inline ``` python title="test_is_now.py" -from inline_snapshot import snapshot from datetime import datetime - from dirty_equals import IsNow # (1)! +from inline_snapshot import snapshot def test_is_now(): @@ -214,9 +213,7 @@ class InlineSnapshotPlugin: ``` python title="test_long_strings.py" -from inline_snapshot import snapshot - -from inline_snapshot import external +from inline_snapshot import external, snapshot def test_long_strings(): @@ -301,8 +298,8 @@ The problem is that `--inline-snapshot=create` puts your secret into your test. ``` python -from inline_snapshot import snapshot from get_data import get_data +from inline_snapshot import snapshot def test_my_class(): @@ -336,8 +333,8 @@ Inline-snapshot will now create the correct code and import statement when you r ``` python hl_lines="4 5 9" -from inline_snapshot import snapshot from get_data import get_data +from inline_snapshot import snapshot from my_secrets import secrets diff --git a/docs/testing.md b/docs/testing.md index 19977812..4d9e4c9f 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -5,8 +5,8 @@ The following example shows how you can use the `Example` class to test what inl ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -30,8 +30,8 @@ Inline-snapshot will then populate the empty snapshots. ``` python hl_lines="17 18 19 20 21 22 23 24 25 26" -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -68,8 +68,8 @@ This allows for more complex tests where you create one example and perform mult ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -116,8 +116,8 @@ You can also use the same example multiple times and call different methods on i ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): diff --git a/pyproject.toml b/pyproject.toml index b89e05d9..1cf625e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ dev = [ "attrs>=24.3.0", "pydantic>=1", "black==25.1.0" + "isort" ] [project.entry-points.pytest11] @@ -169,18 +170,7 @@ matrix.extra-deps.dependencies = [ ] [tool.hatch.envs.hatch-test] -extra-dependencies = [ - "inline-snapshot[black,dirty-equals]", - "dirty-equals>=0.9.0", - "hypothesis>=6.75.5", - "mypy>=1.2.0 ; implementation_name == 'cpython'", - "pyright>=1.1.359", - "pytest-freezer>=0.4.8", - "pytest-mock>=3.14.0", - "black==25.1.0", - "setuptools" - "attrs", -] +dependency-groups=["dev"] env-vars.TOP = "{root}" [tool.hatch.envs.hatch-test.scripts] @@ -190,13 +180,9 @@ cov-combine = "coverage combine" cov-report=["coverage report","coverage html"] [tool.hatch.envs.types] +dependency-groups=["dev"] extra-dependencies = [ - "inline-snapshot[black,dirty-equals]", "mypy>=1.0.0", - "hypothesis>=6.75.5", - "pydantic", - "attrs", - "typing-extensions" ] [[tool.hatch.envs.types.matrix]] diff --git a/src/inline_snapshot/_types.py b/src/inline_snapshot/_types.py index 1f221c27..d0fb6caa 100644 --- a/src/inline_snapshot/_types.py +++ b/src/inline_snapshot/_types.py @@ -35,7 +35,7 @@ class Snapshot(Protocol[T]): ``` python from typing import Optional - from inline_snapshot import snapshot, Snapshot + from inline_snapshot import Snapshot, snapshot # required snapshots diff --git a/src/inline_snapshot/extra.py b/src/inline_snapshot/extra.py index f840aaf5..1aad98e9 100644 --- a/src/inline_snapshot/extra.py +++ b/src/inline_snapshot/extra.py @@ -83,9 +83,9 @@ def prints(*, stdout: Snapshot[str] = "", stderr: Snapshot[str] = ""): ``` python + import sys from inline_snapshot import snapshot from inline_snapshot.extra import prints - import sys def test_prints(): @@ -98,9 +98,9 @@ def test_prints(): ``` python hl_lines="7 8 9" + import sys from inline_snapshot import snapshot from inline_snapshot.extra import prints - import sys def test_prints(): @@ -114,11 +114,11 @@ def test_prints(): === "ignore stdout" - ``` python hl_lines="3 9 10" + ``` python hl_lines="2 9 10" + import sys + from dirty_equals import IsStr from inline_snapshot import snapshot from inline_snapshot.extra import prints - from dirty_equals import IsStr - import sys def test_prints(): @@ -168,9 +168,9 @@ def warns( ``` python + from warnings import warn from inline_snapshot import snapshot from inline_snapshot.extra import warns - from warnings import warn def test_warns(): @@ -182,9 +182,9 @@ def test_warns(): ``` python hl_lines="7" + from warnings import warn from inline_snapshot import snapshot from inline_snapshot.extra import warns - from warnings import warn def test_warns(): @@ -272,9 +272,9 @@ def test_request(): ``` python - from inline_snapshot.extra import Transformed - from inline_snapshot import snapshot import re + from inline_snapshot import snapshot + from inline_snapshot.extra import Transformed class Thing: @@ -337,9 +337,9 @@ def transformation(func): ``` python - from inline_snapshot.extra import transformation - from inline_snapshot import snapshot import re + from inline_snapshot import snapshot + from inline_snapshot.extra import transformation class Thing: diff --git a/tests/test_docs.py b/tests/test_docs.py index 27c453ca..503d5305 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -13,6 +13,7 @@ from typing import Optional from typing import TypeVar +import isort.api import pytest from executing import is_pytest_compatible @@ -390,6 +391,15 @@ def test_block(block: Block): assert last_code is not None test_files = {"tests/test_example.py": last_code} else: + code = isort.api.sort_code_string( + code, + config=isort.Config( + profile="black", + combine_as_imports=True, + lines_between_sections=0, + ), + ) + block.code = code test_files = {"tests/test_example.py": code} example = Example({**std_files, **test_files}) From e4908c6e9c9842f2bd803cf0e92d97b56047f157 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 17:02:13 +0100 Subject: [PATCH 41/72] docs: builder documentation --- src/inline_snapshot/_customize/_builder.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index d94bd751..ec1a60ce 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -65,6 +65,9 @@ def _get_handler(self, v) -> Custom: def create_external( self, value: Any, format: str | None = None, storage: str | None = None ): + """ + Creates a new `external()` with the given format and storage. + """ return CustomExternal(value, format=format, storage=storage) @@ -73,7 +76,7 @@ def create_list(self, value) -> Custom: Creates an intermediate node for a list-expression which can be used as a result for your customization function. `create_list([1,2,3])` becomes `[1,2,3]` in the code. - List elements are recursively converted into CustomNodes. + List elements don't have to be Custom nodes and are converted by inline-snapshot if needed. """ custom = [self._get_handler(v) for v in value] return CustomList(value=custom) @@ -83,7 +86,7 @@ def create_tuple(self, value) -> Custom: Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. - Tuple elements are recursively converted into CustomNodes. + Tuple elements don't have to be Custom nodes and are converted by inline-snapshot if needed. """ custom = [self._get_handler(v) for v in value] return CustomTuple(value=custom) @@ -95,7 +98,7 @@ def create_call( Creates an intermediate node for a function call expression which can be used as a result for your customization function. `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. - Function, arguments, and keyword arguments are recursively converted into CustomNodes. + Function, arguments, and keyword arguments don't have to be Custom nodes and are converted by inline-snapshot if needed. """ function = self._get_handler(function) posonly_args = [self._get_handler(arg) for arg in posonly_args] @@ -114,7 +117,7 @@ def create_default(self, value) -> Custom: Creates an intermediate node for a default value which can be used as a result for your customization function. Default values are not included in the generated code when they match the actual default. - The value is recursively converted into a CustomNode. + The value doesn't have to be a Custom node and is converted by inline-snapshot if needed. """ return CustomDefault(value=self._get_handler(value)) @@ -123,7 +126,7 @@ def create_dict(self, value) -> Custom: Creates an intermediate node for a dict-expression which can be used as a result for your customization function. `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. - Dict keys and values are recursively converted into CustomNodes. + Dict keys and values don't have to be Custom nodes and are converted by inline-snapshot if needed. """ custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) @@ -135,7 +138,6 @@ def create_code(self, value, repr: str | None = None) -> CustomCode: `create_code(value, '{value-1!r}+1')` becomes `4+1` in the code for a given `value=5`. Use this when you need to control the exact string representation of a value. - You can use `.with_import(module,name)` to create an import in the code. - `create_code(Color.red,"Color.red").with_import("my_colors","Color")` will create a `from my_colors import Color` if needed and `Color.red` in the code. + You can use [`.with_import(module,name)`][inline_snapshot.plugin.CustomCode.with_import] to create an import in the code. """ return CustomCode(value, repr) From 832580b7bb636af7aa320aca05947dde89c858bf Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 19:25:44 +0100 Subject: [PATCH 42/72] refactor: renamed map and made it private --- src/inline_snapshot/_customize/_custom.py | 4 ++-- src/inline_snapshot/_customize/_custom_call.py | 12 ++++++------ src/inline_snapshot/_customize/_custom_dict.py | 4 ++-- src/inline_snapshot/_customize/_custom_external.py | 2 +- src/inline_snapshot/_customize/_custom_sequence.py | 4 ++-- src/inline_snapshot/_customize/_custom_undefined.py | 2 +- src/inline_snapshot/_customize/_custom_unmanaged.py | 2 +- src/inline_snapshot/_customize/_custom_value.py | 2 +- src/inline_snapshot/_get_snapshot_value.py | 2 +- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 08243dd7..3e9a84d9 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -31,7 +31,7 @@ def __eq__(self, other): return self.eval() == other.eval() @abstractmethod - def map(self, f): + def _map(self, f): raise NotImplementedError() @abstractmethod @@ -39,7 +39,7 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def eval(self): - return self.map(lambda a: a) + return self._map(lambda a: a) def _needed_imports(self): yield from () diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py index ddd69372..388689f8 100644 --- a/src/inline_snapshot/_customize/_custom_call.py +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -20,8 +20,8 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: # this should never be called because default values are never converted into code assert False - def map(self, f): - return self.value.map(f) + def _map(self, f): + return self.value._map(f) def _needed_imports(self): yield from self.value._needed_imports() @@ -72,10 +72,10 @@ def argument(self, pos_or_str): else: return unwrap_default(self.kwargs[pos_or_str]) - def map(self, f): - return self._function.map(f)( - *[f(x.map(f)) for x in self._args], - **{k: f(v.map(f)) for k, v in self.kwargs.items()}, + def _map(self, f): + return self._function._map(f)( + *[f(x._map(f)) for x in self._args], + **{k: f(v._map(f)) for k, v in self.kwargs.items()}, ) def _needed_imports(self): diff --git a/src/inline_snapshot/_customize/_custom_dict.py b/src/inline_snapshot/_customize/_custom_dict.py index fafd49bd..1f2225d0 100644 --- a/src/inline_snapshot/_customize/_custom_dict.py +++ b/src/inline_snapshot/_customize/_custom_dict.py @@ -16,8 +16,8 @@ class CustomDict(Custom): node_type = ast.Dict value: dict[Custom, Custom] = field(compare=False) - def map(self, f): - return f({k.map(f): v.map(f) for k, v in self.value.items()}) + def _map(self, f): + return f({k._map(f): v._map(f) for k, v in self.value.items()}) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: values = [] diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index 91a619e9..62cca398 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -20,7 +20,7 @@ class CustomExternal(Custom): format: str | None = None storage: str | None = None - def map(self, f): + def _map(self, f): return f(self.value) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: diff --git a/src/inline_snapshot/_customize/_custom_sequence.py b/src/inline_snapshot/_customize/_custom_sequence.py index ad14878e..e1d5f2fc 100644 --- a/src/inline_snapshot/_customize/_custom_sequence.py +++ b/src/inline_snapshot/_customize/_custom_sequence.py @@ -21,8 +21,8 @@ class CustomSequenceTypes: class CustomSequence(Custom, CustomSequenceTypes): value: list[Custom] = field(compare=False) - def map(self, f): - return f(self.value_type([x.map(f) for x in self.value])) + def _map(self, f): + return f(self.value_type([x._map(f) for x in self.value])) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: values = [] diff --git a/src/inline_snapshot/_customize/_custom_undefined.py b/src/inline_snapshot/_customize/_custom_undefined.py index f7b06344..23f93301 100644 --- a/src/inline_snapshot/_customize/_custom_undefined.py +++ b/src/inline_snapshot/_customize/_custom_undefined.py @@ -17,5 +17,5 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () return "..." - def map(self, f): + def _map(self, f): return f(undefined) diff --git a/src/inline_snapshot/_customize/_custom_unmanaged.py b/src/inline_snapshot/_customize/_custom_unmanaged.py index e909ea6d..2b18bfaa 100644 --- a/src/inline_snapshot/_customize/_custom_unmanaged.py +++ b/src/inline_snapshot/_customize/_custom_unmanaged.py @@ -18,5 +18,5 @@ def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () # pragma: no cover return "'unmanaged'" - def map(self, f): + def _map(self, f): return f(self.value) diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index bc946374..dcb27d66 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -46,7 +46,7 @@ def __init__(self, value, repr_str=None): super().__init__() - def map(self, f): + def _map(self, f): return f(self.value) def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: diff --git a/src/inline_snapshot/_get_snapshot_value.py b/src/inline_snapshot/_get_snapshot_value.py index b34e6a1f..a6bfffe6 100644 --- a/src/inline_snapshot/_get_snapshot_value.py +++ b/src/inline_snapshot/_get_snapshot_value.py @@ -13,7 +13,7 @@ def unwrap(value): if isinstance(value, GenericValue): - return value._visible_value().map(lambda v: unwrap(v)[0]), True + return value._visible_value()._map(lambda v: unwrap(v)[0]), True if isinstance(value, Outsourced): return (value.data, True) From d8d0d1c419def77faca8ba9e1aa84893bbc310ba Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 19:40:12 +0100 Subject: [PATCH 43/72] refactor: renamed eval to _eval --- src/inline_snapshot/_customize/_custom.py | 6 +++--- src/inline_snapshot/_new_adapter.py | 14 +++++++------- src/inline_snapshot/_snapshot/collection_value.py | 8 ++++---- src/inline_snapshot/_snapshot/eq_value.py | 4 ++-- src/inline_snapshot/_snapshot/generic_value.py | 2 +- src/inline_snapshot/_snapshot/min_max_value.py | 14 +++++++------- src/inline_snapshot/_snapshot/undecided_value.py | 2 +- 7 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 3e9a84d9..9c953519 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -24,11 +24,11 @@ class Custom(ABC): original_value: Any def __hash__(self): - return hash(self.eval()) + return hash(self._eval()) def __eq__(self, other): assert isinstance(other, Custom) - return self.eval() == other.eval() + return self._eval() == other._eval() @abstractmethod def _map(self, f): @@ -38,7 +38,7 @@ def _map(self, f): def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: raise NotImplementedError() - def eval(self): + def _eval(self): return self._map(lambda a: a) def _needed_imports(self): diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 6102cbd8..1db5bf59 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -108,7 +108,7 @@ def reeval_CustomUndefined(old_value, value): def reeval_CustomCode(old_value: CustomCode, value: CustomCode): - if not old_value.eval() == value.eval(): + if not old_value._eval() == value._eval(): raise UsageError( "snapshot value should not change. Use Is(...) for dynamic snapshot parts." ) @@ -189,7 +189,7 @@ def compare_CustomCode( and isinstance(new_value, CustomCode) and isinstance(new_value.value, str) ): - if not old_value.eval() == new_value.eval(): + if not old_value._eval() == new_value._eval(): value = only_value(new_value.repr(self.context)) warnings.warn_explicit( @@ -200,7 +200,7 @@ def compare_CustomCode( ) return old_value - if not old_value.eval() == new_value.original_value: + if not old_value._eval() == new_value.original_value: if isinstance(old_value, CustomUndefined): flag = "create" else: @@ -218,7 +218,7 @@ def compare_CustomCode( file=self.context.file, new_code=new_code, flag=flag, - old_value=old_value.eval(), + old_value=old_value._eval(), new_value=new_value, ) @@ -239,7 +239,7 @@ def compare_CustomSequence( if old_node is not None: assert isinstance( - old_node, ast.List if isinstance(old_value.eval(), list) else ast.Tuple + old_node, ast.List if isinstance(old_value._eval(), list) else ast.Tuple ) assert isinstance(old_node, (ast.List, ast.Tuple)) @@ -311,7 +311,7 @@ def compare_CustomDict( node_value = ast.literal_eval(node) except Exception: continue - assert node_value == value2.eval() + assert node_value == value2._eval() else: pass # pragma: no cover @@ -400,7 +400,7 @@ def compare_CustomCall( result_args = [] - flag = "update" if old_value.eval() == new_value.original_value else "fix" + flag = "update" if old_value._eval() == new_value.original_value else "fix" if flag == "update": diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 7afa3643..2f833033 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -26,13 +26,13 @@ def __contains__(self, item): if isinstance(self._new_value, CustomUndefined): self._new_value = CustomList([self.to_custom(item)]) else: - if item not in self._new_value.eval(): + if item not in self._new_value._eval(): self._new_value.value.append(self.to_custom(item)) if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True else: - return self._return(item in self._old_value.eval()) + return self._return(item in self._old_value._eval()) def _new_code(self) -> Generator[ChangeBase, None, str]: code = yield from self._new_value.repr(self._context) @@ -59,7 +59,7 @@ def _get_changes(self) -> Iterator[ChangeBase]: continue # check for update - new_code = yield from self.to_custom(old_value.eval()).repr(self._context) + new_code = yield from self.to_custom(old_value._eval()).repr(self._context) if self._file.code_changed(old_node, new_code): @@ -78,7 +78,7 @@ def _get_changes(self) -> Iterator[ChangeBase]: if v not in self._old_value.value: new_code = yield from v.repr(self._context) new_codes.append(new_code) - new_values.append(v.eval()) + new_values.append(v._eval()) if new_codes: yield ListInsert( diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index d1502928..d8afd335 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -35,8 +35,8 @@ def __eq__(self, other): self._new_value = result.value return self._return( - self._old_value.eval() == other, - self._new_value.eval() == other, + self._old_value._eval() == other, + self._new_value._eval() == other, ) def _new_code(self) -> Generator[ChangeBase, None, str]: diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 2ce33fe2..efa3de5a 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -93,7 +93,7 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def __repr__(self): - return repr(self._visible_value().eval()) + return repr(self._visible_value()._eval()) def _type_error(self, op): __tracebackhide__ = True diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 1d70dba6..0d3ffd76 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -25,12 +25,12 @@ def _generic_cmp(self, other): self._new_value = self.to_custom(other) if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True - return self._return(self.cmp(self._old_value.eval(), other)) + return self._return(self.cmp(self._old_value._eval(), other)) else: - if not self.cmp(self._new_value.eval(), other): + if not self.cmp(self._new_value._eval(), other): self._new_value = self.to_custom(other) - return self._return(self.cmp(self._visible_value().eval(), other)) + return self._return(self.cmp(self._visible_value()._eval(), other)) def _new_code(self) -> Generator[ChangeBase, None, str]: code = yield from self._new_value.repr(self._context) @@ -39,9 +39,9 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: def _get_changes(self) -> Iterator[ChangeBase]: new_code = yield from self._new_code() - if not self.cmp(self._old_value.eval(), self._new_value.eval()): + if not self.cmp(self._old_value._eval(), self._new_value._eval()): flag = "fix" - elif not self.cmp(self._new_value.eval(), self._old_value.eval()): + elif not self.cmp(self._new_value._eval(), self._old_value._eval()): flag = "trim" elif self._file.code_changed(self._ast_node, new_code): flag = "update" @@ -53,8 +53,8 @@ def _get_changes(self) -> Iterator[ChangeBase]: file=self._file, new_code=new_code, flag=flag, - old_value=self._old_value.eval(), - new_value=self._new_value.eval(), + old_value=self._old_value._eval(), + new_value=self._new_value._eval(), ) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index a4cdb9ed..388588e0 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -89,7 +89,7 @@ def _new_code(self): def _get_changes(self) -> Iterator[ChangeBase]: assert isinstance(self._new_value, CustomUndefined) - new_value = self.to_custom(self._old_value.eval()) + new_value = self.to_custom(self._old_value._eval()) adapter = NewAdapter(self._context) From 85cb3d4ebaebc30f516868ecafaaf0aaf489d4e2 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 20:10:03 +0100 Subject: [PATCH 44/72] refactor: renamed repr to _code_repr --- src/inline_snapshot/_code_repr.py | 2 +- src/inline_snapshot/_customize/_custom.py | 2 +- .../_customize/_custom_call.py | 10 +++++----- .../_customize/_custom_dict.py | 6 +++--- .../_customize/_custom_external.py | 2 +- .../_customize/_custom_sequence.py | 4 ++-- .../_customize/_custom_undefined.py | 2 +- .../_customize/_custom_unmanaged.py | 2 +- .../_customize/_custom_value.py | 2 +- src/inline_snapshot/_new_adapter.py | 20 +++++++++---------- .../_snapshot/collection_value.py | 8 +++++--- src/inline_snapshot/_snapshot/dict_value.py | 4 ++-- src/inline_snapshot/_snapshot/eq_value.py | 2 +- .../_snapshot/min_max_value.py | 2 +- 14 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 8add9201..02c52104 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -88,7 +88,7 @@ def new_repr(obj): from inline_snapshot._customize._builder import Builder return only_value( - Builder(_snapshot_context=context)._get_handler(obj).repr(context) + Builder(_snapshot_context=context)._get_handler(obj)._code_repr(context) ) with mock.patch("builtins.repr", new_repr): diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index 9c953519..bb6d8a2f 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -35,7 +35,7 @@ def _map(self, f): raise NotImplementedError() @abstractmethod - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def _eval(self): diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py index 388689f8..91a8fdf9 100644 --- a/src/inline_snapshot/_customize/_custom_call.py +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -15,7 +15,7 @@ class CustomDefault(Custom): value: Custom = field(compare=False) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () # pragma: no cover # this should never be called because default values are never converted into code assert False @@ -41,18 +41,18 @@ class CustomCall(Custom): _kwargs: dict[str, Custom] = field(compare=False) _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: args = [] for a in self.args: - v = yield from a.repr(context) + v = yield from a._code_repr(context) args.append(v) for k, v in self.kwargs.items(): if not isinstance(v, CustomDefault): - value = yield from v.repr(context) + value = yield from v._code_repr(context) args.append(f"{k}={value}") - return f"{yield from self._function.repr(context)}({', '.join(args)})" + return f"{yield from self._function._code_repr(context)}({', '.join(args)})" @property def args(self): diff --git a/src/inline_snapshot/_customize/_custom_dict.py b/src/inline_snapshot/_customize/_custom_dict.py index 1f2225d0..f0f28874 100644 --- a/src/inline_snapshot/_customize/_custom_dict.py +++ b/src/inline_snapshot/_customize/_custom_dict.py @@ -19,11 +19,11 @@ class CustomDict(Custom): def _map(self, f): return f({k._map(f): v._map(f) for k, v in self.value.items()}) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: values = [] for k, v in self.value.items(): - key = yield from k.repr(context) - value = yield from v.repr(context) + key = yield from k._code_repr(context) + value = yield from v._code_repr(context) values.append(f"{key}: {value}") return f"{{{ ', '.join(values)}}}" diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index 62cca398..ceed5906 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -23,7 +23,7 @@ class CustomExternal(Custom): def _map(self, f): return f(self.value) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: from inline_snapshot._global_state import state storage_name = self.storage or state().config.default_storage diff --git a/src/inline_snapshot/_customize/_custom_sequence.py b/src/inline_snapshot/_customize/_custom_sequence.py index e1d5f2fc..06389648 100644 --- a/src/inline_snapshot/_customize/_custom_sequence.py +++ b/src/inline_snapshot/_customize/_custom_sequence.py @@ -24,10 +24,10 @@ class CustomSequence(Custom, CustomSequenceTypes): def _map(self, f): return f(self.value_type([x._map(f) for x in self.value])) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: values = [] for v in self.value: - value = yield from v.repr(context) + value = yield from v._code_repr(context) values.append(value) trailing_comma = self.trailing_comma and len(self.value) == 1 diff --git a/src/inline_snapshot/_customize/_custom_undefined.py b/src/inline_snapshot/_customize/_custom_undefined.py index 23f93301..90da8e69 100644 --- a/src/inline_snapshot/_customize/_custom_undefined.py +++ b/src/inline_snapshot/_customize/_custom_undefined.py @@ -13,7 +13,7 @@ class CustomUndefined(Custom): def __init__(self): self.value = undefined - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () return "..." diff --git a/src/inline_snapshot/_customize/_custom_unmanaged.py b/src/inline_snapshot/_customize/_custom_unmanaged.py index 2b18bfaa..f85acb8b 100644 --- a/src/inline_snapshot/_customize/_custom_unmanaged.py +++ b/src/inline_snapshot/_customize/_custom_unmanaged.py @@ -14,7 +14,7 @@ class CustomUnmanaged(Custom): value: Any - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () # pragma: no cover return "'unmanaged'" diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index dcb27d66..bde9f296 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -49,7 +49,7 @@ def __init__(self, value, repr_str=None): def _map(self, f): return f(self.value) - def repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: yield from () return self.repr_str diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 1db5bf59..2ed1478b 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -182,7 +182,7 @@ def compare_CustomCode( if old_node is None: new_code = "" else: - new_code = yield from new_value.repr(self.context) + new_code = yield from new_value._code_repr(self.context) if ( isinstance(old_node, ast.JoinedStr) @@ -191,7 +191,7 @@ def compare_CustomCode( ): if not old_value._eval() == new_value._eval(): - value = only_value(new_value.repr(self.context)) + value = only_value(new_value._code_repr(self.context)) warnings.warn_explicit( f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {value}", filename=self.context.file._source.filename, @@ -267,7 +267,7 @@ def compare_CustomSequence( old_position += 1 elif c == "i": new_value_element = next(new) - new_code = yield from new_value_element.repr(self.context) + new_code = yield from new_value_element._code_repr(self.context) result.append(new_value_element) to_insert[old_position].append((new_code, new_value_element)) elif c == "d": @@ -349,8 +349,8 @@ def compare_CustomDict( if to_insert: new_code = [] for k, v in to_insert: - new_code_key = yield from k.repr(self.context) - new_code_value = yield from v.repr(self.context) + new_code_key = yield from k._code_repr(self.context) + new_code_value = yield from v._code_repr(self.context) new_code.append((new_code_key, new_code_value)) yield DictInsert( @@ -368,8 +368,8 @@ def compare_CustomDict( if to_insert: new_code = [] for k, v in to_insert: - new_code_key = yield from k.repr(self.context) - new_code_value = yield from v.repr(self.context) + new_code_key = yield from k._code_repr(self.context) + new_code_value = yield from v._code_repr(self.context) new_code.append( ( new_code_key, @@ -444,7 +444,7 @@ def intercept(stream): if old_args_len < len(new_args): for insert_pos, value in list(enumerate(new_args))[old_args_len:]: - new_code = yield from value.repr(self.context) + new_code = yield from value._code_repr(self.context) yield CallArg( flag=flag, file=self.context.file, @@ -498,7 +498,7 @@ def intercept(stream): if to_insert: for key, value in to_insert: - new_code = yield from value.repr(self.context) + new_code = yield from value._code_repr(self.context) yield CallArg( flag=flag, file=self.context.file, @@ -515,7 +515,7 @@ def intercept(stream): if to_insert: for key, value in to_insert: - new_code = yield from value.repr(self.context) + new_code = yield from value._code_repr(self.context) yield CallArg( flag=flag, diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 2f833033..681fe7dd 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -35,7 +35,7 @@ def __contains__(self, item): return self._return(item in self._old_value._eval()) def _new_code(self) -> Generator[ChangeBase, None, str]: - code = yield from self._new_value.repr(self._context) + code = yield from self._new_value._code_repr(self._context) return code def _get_changes(self) -> Iterator[ChangeBase]: @@ -59,7 +59,9 @@ def _get_changes(self) -> Iterator[ChangeBase]: continue # check for update - new_code = yield from self.to_custom(old_value._eval()).repr(self._context) + new_code = yield from self.to_custom(old_value._eval())._code_repr( + self._context + ) if self._file.code_changed(old_node, new_code): @@ -76,7 +78,7 @@ def _get_changes(self) -> Iterator[ChangeBase]: new_values = [] for v in self._new_value.value: if v not in self._old_value.value: - new_code = yield from v.repr(self._context) + new_code = yield from v._code_repr(self._context) new_codes.append(new_code) new_values.append(v._eval()) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index bb92dd8d..3071f144 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -63,7 +63,7 @@ def _new_code(self) -> Generator[ChangeBase, None, str]: for k, v in self._new_value.value.items(): if not isinstance(v, UndecidedValue): new_code = yield from v._new_code() # type:ignore - new_key = yield from k.repr(self._context) + new_key = yield from k._code_repr(self._context) values.append(f"{new_key}: {new_code}") return "{" + ", ".join(values) + "}" @@ -94,7 +94,7 @@ def _get_changes(self) -> Iterator[ChangeBase]: ): # add new values new_value = yield from new_value_element._new_code() # type:ignore - new_key = yield from key.repr(self._context) + new_key = yield from key._code_repr(self._context) to_insert.append((new_key, new_value)) to_insert_values.append((key, new_value_element)) diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index d8afd335..1cb39e54 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -40,7 +40,7 @@ def __eq__(self, other): ) def _new_code(self) -> Generator[ChangeBase, None, str]: - code = yield from self._new_value.repr(self._context) + code = yield from self._new_value._code_repr(self._context) return code def _get_changes(self) -> Iterator[Change]: diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 0d3ffd76..4af440df 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -33,7 +33,7 @@ def _generic_cmp(self, other): return self._return(self.cmp(self._visible_value()._eval(), other)) def _new_code(self) -> Generator[ChangeBase, None, str]: - code = yield from self._new_value.repr(self._context) + code = yield from self._new_value._code_repr(self._context) return code def _get_changes(self) -> Iterator[ChangeBase]: From 7a20636ea74f04528606bab6c49caa3c6760be4f Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 20:46:40 +0100 Subject: [PATCH 45/72] feat: support for defaults in named tuples --- src/inline_snapshot/plugin/_default_plugin.py | 11 +- tests/adapter/test_namedtuple.py | 227 ++++++++++++++++++ 2 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 tests/adapter/test_namedtuple.py diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index f22a0339..492d8ad9 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -181,16 +181,17 @@ def namedtuple_handler(self, value, builder: Builder): if not all(type(n) == str for n in f): return - # TODO handle with builder.Default - return builder.create_call( type(value), [], { - field: getattr(value, field) + field: ( + getattr(value, field) + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + else builder.create_default(value._field_defaults[field]) + ) for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] }, {}, ) diff --git a/tests/adapter/test_namedtuple.py b/tests/adapter/test_namedtuple.py new file mode 100644 index 00000000..74fdd417 --- /dev/null +++ b/tests/adapter/test_namedtuple.py @@ -0,0 +1,227 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_namedtuple_default_value(): + # Note: namedtuples with defaults are created using a different approach + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1, b=2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_add_arguments(): + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b'], defaults=[2]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1, b=5) == snapshot(A(a=1)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b'], defaults=[2]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1, b=5) == snapshot(A(a=1, b=5)) +""" + } + ), + ) + + +def test_namedtuple_positional_arguments(): + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(1, 2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_typing(): + Example( + """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int + +def test_something(): + assert A(a=1, b=2) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int + +def test_something(): + assert A(a=1, b=2) == snapshot(A(a=1, b=2)) +""" + } + ), + ) + + +def test_namedtuple_typing_defaults(): + Example( + """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int = 2 + c: list = [] + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1, b=2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int = 2 + c: list = [] + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_nested(): + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +Inner = namedtuple('Inner', ['x', 'y']) +Outer = namedtuple('Outer', ['a', 'inner']) + +def test_something(): + assert Outer(a=1, inner=Inner(x=2, y=3)) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +Inner = namedtuple('Inner', ['x', 'y']) +Outer = namedtuple('Outer', ['a', 'inner']) + +def test_something(): + assert Outer(a=1, inner=Inner(x=2, y=3)) == snapshot(Outer(a=1, inner=Inner(x=2, y=3))) +""" + } + ), + ) + + +def test_namedtuple_mixed_args(): + # Test mixing positional and keyword arguments + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c']) + +def test_something(): + assert A(1, b=2, c=3) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c']) + +def test_something(): + assert A(1, b=2, c=3) == snapshot(A(a=1, b=2, c=3)) +""" + } + ), + ) From 9b2593a305a398ce81583491e26003a34b62e576 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 11 Jan 2026 20:52:08 +0100 Subject: [PATCH 46/72] docs: fixed some documentation errors --- src/inline_snapshot/_customize/_custom_value.py | 4 +++- src/inline_snapshot/plugin/_spec.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index bde9f296..9c7f4dc0 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -4,6 +4,8 @@ import importlib from typing import Generator +from typing_extensions import Self + from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase from inline_snapshot._code_repr import HasRepr @@ -59,7 +61,7 @@ def __repr__(self): def _needed_imports(self): yield from self._imports - def with_import(self, module, name, simplify=True): + def with_import(self, module: str, name: str, simplify: bool = True) -> Self: """ Adds a `from module import name` statement to the generated code. diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index c2c9fbce..ddf3967d 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -53,7 +53,7 @@ def customize( `name` and `value` attributes. Returns: - (Custom): created using [Builder][inline_snapshot.Builder] `create_*` methods. + (Custom): created using [Builder][inline_snapshot.plugin.Builder] `create_*` methods. (None): if this handler doesn't apply to the given value. (Something else): when the next handler should process the value. From ceee21dc9ad1b414536a0ce09dc770fdc2e1e1c2 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 14 Jan 2026 18:17:32 +0100 Subject: [PATCH 47/72] fix: several bugfixes which I found while I tested pydantic-ai --- src/inline_snapshot/_customize/_builder.py | 5 ++ .../_customize/_custom_value.py | 1 + .../_external/_find_external.py | 9 +- src/inline_snapshot/_new_adapter.py | 8 +- src/inline_snapshot/_utils.py | 12 ++- src/inline_snapshot/plugin/_default_plugin.py | 2 +- tests/adapter/test_dataclass.py | 87 +++++++++++++++++++ tests/external/test_external.py | 22 +++++ 8 files changed, 138 insertions(+), 8 deletions(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index ec1a60ce..cc33c9e2 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -62,6 +62,11 @@ def _get_handler(self, v) -> Custom: result.__dict__["original_value"] = v return result + def with_default(self, value: Any, default: Any): + if value == default: + return CustomDefault(value=self._get_handler(value)) + return value + def create_external( self, value: Any, format: str | None = None, storage: str | None = None ): diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_value.py index 9c7f4dc0..544f536d 100644 --- a/src/inline_snapshot/_customize/_custom_value.py +++ b/src/inline_snapshot/_customize/_custom_value.py @@ -80,6 +80,7 @@ def with_import(self, module: str, name: str, simplify: bool = True) -> Self: builder.create_value(my_obj, "secrets[0]").with_import("my_secrets", "secrets") ``` """ + name = name.split("[")[0] if simplify: module = _simplify_module_path(module, name) self._imports.append([module, name]) diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index e00c2eb1..af989f2a 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -123,7 +123,14 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): last_import = None for node in tree.body: - if not isinstance(node, (ast.ImportFrom, ast.Import)): + if not ( + isinstance(node, (ast.ImportFrom, ast.Import)) + or ( + isinstance(node, ast.Expr) + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + ) + ): break last_import = node diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 2ed1478b..ae6951a3 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -463,14 +463,14 @@ def intercept(stream): old_keywords = {kw.arg: kw.value for kw in old_node.keywords} for kw_arg, kw_value in old_keywords.items(): - if kw_arg not in new_kwargs or isinstance( - new_kwargs[kw_arg], CustomDefault - ): + missing = kw_arg not in new_kwargs + if missing or isinstance(new_kwargs[kw_arg], CustomDefault): # delete entries yield Delete( ( "update" - if old_value.argument(kw_arg) == new_value.argument(kw_arg) + if not missing + and old_value.argument(kw_arg) == new_value.argument(kw_arg) else flag ), self.context.file, diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 492d8312..48d4f1c2 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -6,7 +6,7 @@ from inline_snapshot._exceptions import UsageError -from ._code_repr import value_code_repr +from ._code_repr import real_repr def link(text, link=None): @@ -156,7 +156,7 @@ def clone(obj): inline-snapshot uses `copy.deepcopy` to copy objects, but the copied object is not equal to the original one: -value = {value_code_repr(obj)} +value = {real_repr(obj)} copied_value = copy.deepcopy(value) assert value == copied_value @@ -164,3 +164,11 @@ def clone(obj): """ ) return new + + +def clone_if_equal(obj): + new = copy.deepcopy(obj) + if obj == new: + return new + else: + return obj diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 492d8ad9..9fcb62df 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -70,7 +70,7 @@ def builtin_function_handler(self, value, builder: Builder): @customize def type_handler(self, value, builder: Builder): if isinstance(value, type): - qualname = value.__qualname__ + qualname = value.__qualname__.split("[")[0] name = qualname.split(".")[0] return builder.create_code(value, qualname).with_import( value.__module__, name diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index cfb81682..4094d24d 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -685,3 +685,90 @@ def test_list(): } ), ) + + +def test_dataclass_custom_init(): + + Example( + """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(a=5)) + + assert A(a=5) == snapshot(A(_a=4)) + assert A(a=5) == snapshot(A(a=4)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(a=5)) + + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) +""" + } + ), + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) + + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) +""" + } + ), + ) diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 4148ea6c..1158d860 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -527,6 +527,28 @@ def test_ensure_imports_with_comment(tmp_path): ) +def test_ensure_imports_with_docstring(tmp_path): + file = tmp_path / "file.py" + file.write_bytes( + b"""\ +''' docstring ''' +from __future__ import annotations +""" + ) + + with apply_changes() as recorder: + ensure_import(file, {"os": ["chdir"]}, recorder) + + assert file.read_text("utf-8") == snapshot( + """\ +''' docstring ''' +from __future__ import annotations + +from os import chdir +""" + ) + + def test_new_externals(project): project.pyproject( """\ From e3fc57c4d1fb9ec39a02674a0efa75ff2098c1fb Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 14 Jan 2026 19:10:47 +0100 Subject: [PATCH 48/72] refactor: used to with_default and removed create_default --- src/inline_snapshot/_customize/_builder.py | 29 +++++------ src/inline_snapshot/plugin/_default_plugin.py | 50 ++++++++----------- 2 files changed, 35 insertions(+), 44 deletions(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index cc33c9e2..156e666e 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -5,6 +5,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._compare_context import compare_context +from inline_snapshot._exceptions import UsageError from inline_snapshot.plugin._context_variable import ContextVariable from ._custom import Custom @@ -62,11 +63,6 @@ def _get_handler(self, v) -> Custom: result.__dict__["original_value"] = v return result - def with_default(self, value: Any, default: Any): - if value == default: - return CustomDefault(value=self._get_handler(value)) - return value - def create_external( self, value: Any, format: str | None = None, storage: str | None = None ): @@ -96,6 +92,20 @@ def create_tuple(self, value) -> Custom: custom = [self._get_handler(v) for v in value] return CustomTuple(value=custom) + def with_default(self, value: Any, default: Any): + """ + Creates an intermediate node for a default value which can be used as an argument for create_call. + + Arguments are not included in the generated code when they match the actual default. + The value doesn't have to be a Custom node and is converted by inline-snapshot if needed. + """ + if isinstance(default, Custom): + raise UsageError("default value can not be an Custom value") + + if value == default: + return CustomDefault(value=self._get_handler(value)) + return value + def create_call( self, function, posonly_args=[], kwargs={}, kwonly_args={} ) -> Custom: @@ -117,15 +127,6 @@ def create_call( _kwonly=kwonly_args, ) - def create_default(self, value) -> Custom: - """ - Creates an intermediate node for a default value which can be used as a result for your customization function. - - Default values are not included in the generated code when they match the actual default. - The value doesn't have to be a Custom node and is converted by inline-snapshot if needed. - """ - return CustomDefault(value=self._get_handler(value)) - def create_dict(self, value) -> Custom: """ Creates an intermediate node for a dict-expression which can be used as a result for your customization function. diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 9fcb62df..6380e66c 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -152,19 +152,15 @@ def dataclass_handler(self, value, builder: Builder): for field in fields(value): # type: ignore if field.repr: field_value = getattr(value, field.name) - is_default = False - if field.default != MISSING and field.default == field_value: - is_default = True + if field.default != MISSING: + field_value = builder.with_default(field_value, field.default) - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - is_default = True + if field.default_factory != MISSING: + field_value = builder.with_default( + field_value, field.default_factory() + ) - if is_default: - field_value = builder.create_default(field_value) kwargs[field.name] = field_value return builder.create_call(type(value), [], kwargs, {}) @@ -188,8 +184,9 @@ def namedtuple_handler(self, value, builder: Builder): field: ( getattr(value, field) if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - else builder.create_default(value._field_defaults[field]) + else builder.with_default( + getattr(value, field), value._field_defaults[field] + ) ) for field in value._fields }, @@ -264,7 +261,6 @@ def attrs_handler(self, value, builder: Builder): for field in attrs.fields(type(value)): if field.repr: field_value = getattr(value, field.name) - is_default = False if field.default is not attrs.NOTHING: @@ -277,12 +273,9 @@ def attrs_handler(self, value, builder: Builder): else field.default.factory(value) ) ) - - if default_value == field_value: - is_default = True - - if is_default: - field_value = builder.create_default(field_value) + field_value = builder.with_default( + field_value, default_value + ) kwargs[field.name] = field_value @@ -314,7 +307,7 @@ def get_fields(value): class InlineSnapshotPydanticPlugin: @customize - def attrs_handler(self, value, builder: Builder): + def pydantic_model_handler(self, value, builder: Builder): if isinstance(value, BaseModel): @@ -323,22 +316,19 @@ def attrs_handler(self, value, builder: Builder): for name, field in get_fields(value).items(): # type: ignore if getattr(field, "repr", True): field_value = getattr(value, name) - is_default = False if ( field.default is not PydanticUndefined and field.default == field_value ): - is_default = True - - if ( - field.default_factory is not None - and field.default_factory() == field_value - ): - is_default = True + field_value = builder.with_default( + field_value, field.default + ) - if is_default: - field_value = builder.create_default(field_value) + elif field.default_factory is not None: + field_value = builder.with_default( + field_value, field.default_factory() + ) kwargs[name] = field_value From 9da49c09341e52ded98e4d2f4262a12171a0e18c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 14 Jan 2026 19:22:13 +0100 Subject: [PATCH 49/72] refactor: local/global_vars are now a dict --- docs/plugin.md | 10 ++++---- src/inline_snapshot/_customize/_builder.py | 17 +++++++------- src/inline_snapshot/plugin/__init__.py | 2 -- .../plugin/_context_variable.py | 23 ------------------- src/inline_snapshot/plugin/_default_plugin.py | 6 ----- src/inline_snapshot/plugin/_spec.py | 15 +++++------- 6 files changed, 19 insertions(+), 54 deletions(-) delete mode 100644 src/inline_snapshot/plugin/_context_variable.py diff --git a/docs/plugin.md b/docs/plugin.md index 8471e5d4..7db2cbc0 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -242,10 +242,10 @@ from inline_snapshot.plugin import Builder class InlineSnapshotPlugin: @customize - def local_var_handler(self, value, local_vars): - for local in local_vars: - if local.name.startswith("v_") and local.value == value: - return local + def local_var_handler(self, value, builder, local_vars): + for var_name, var_value in local_vars.items(): + if var_name.startswith("v_") and var_value == value: + return builder.create_code(value, var_name) ``` We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local variable if we find one that fits the criteria. @@ -365,7 +365,7 @@ def test_my_class(): ::: inline_snapshot.plugin options: heading_level: 3 - members: [hookimpl,customize,Builder,Custom,CustomCode,ContextVariable] + members: [hookimpl,customize,Builder,Custom,CustomCode] show_root_heading: false show_bases: false show_source: false diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 156e666e..48f7724a 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -6,7 +6,6 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._compare_context import compare_context from inline_snapshot._exceptions import UsageError -from inline_snapshot.plugin._context_variable import ContextVariable from ._custom import Custom from ._custom_call import CustomCall @@ -31,19 +30,19 @@ def _get_handler(self, v) -> Custom: self._snapshot_context is not None and (frame := self._snapshot_context.frame) is not None ): - local_vars = [ - ContextVariable(var_name, var_value) + local_vars = { + var_name: var_value for var_name, var_value in frame.locals.items() if "@" not in var_name - ] - global_vars = [ - ContextVariable(var_name, var_value) + } + global_vars = { + var_name: var_value for var_name, var_value in frame.globals.items() if "@" not in var_name - ] + } else: - local_vars = [] - global_vars = [] + local_vars = {} + global_vars = {} result = v diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py index 8e6c2b06..4dae7995 100644 --- a/src/inline_snapshot/plugin/__init__.py +++ b/src/inline_snapshot/plugin/__init__.py @@ -2,7 +2,6 @@ from .._customize._builder import Builder from .._customize._custom import Custom -from ._context_variable import ContextVariable from ._spec import InlineSnapshotPluginSpec from ._spec import customize from ._spec import hookimpl @@ -12,7 +11,6 @@ "customize", "hookimpl", "Builder", - "ContextVariable", "Custom", "CustomCode", ) diff --git a/src/inline_snapshot/plugin/_context_variable.py b/src/inline_snapshot/plugin/_context_variable.py deleted file mode 100644 index 3f57ede1..00000000 --- a/src/inline_snapshot/plugin/_context_variable.py +++ /dev/null @@ -1,23 +0,0 @@ -from dataclasses import dataclass -from typing import Any - - -@dataclass -class ContextVariable: - """ - Representation of a value in the local or global context of a snapshot. - - This type can be returned from a customize function to reference an existing variable - instead of creating a new literal value. Inline-snapshot includes a built-in handler - that converts ContextVariable instances into [Custom][inline_snapshot.plugin.Custom] - objects, generating code that references the variable by name. - - ContextVariable instances are provided via the `local_vars` and `global_vars` parameters - of the [customize hook][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize]. - """ - - name: str - "the name of the variable" - - value: Any - "the value of the variable" diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 6380e66c..52153591 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -19,7 +19,6 @@ from inline_snapshot._unmanaged import is_dirty_equal from inline_snapshot._unmanaged import is_unmanaged from inline_snapshot._utils import triple_quote -from inline_snapshot.plugin._context_variable import ContextVariable from ._spec import customize @@ -229,11 +228,6 @@ def dirty_equals_handler(self, value, builder: Builder): kwargs.pop("approx") return builder.create_call(type(value), args, kwargs) - @customize(tryfirst=True) - def context_value_handler(self, value, builder: Builder): - if isinstance(value, ContextVariable): - return builder.create_code(value.value, value.name) - @customize def outsource_handler(self, value, builder: Builder): if isinstance(value, Outsourced): diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index ddf3967d..dad9ad89 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -1,11 +1,10 @@ from functools import partial from typing import Any -from typing import List +from typing import Dict import pluggy from inline_snapshot._customize._builder import Builder -from inline_snapshot.plugin._context_variable import ContextVariable inline_snapshot_plugin_name = "inline-snapshot" @@ -31,8 +30,8 @@ def customize( self, value: Any, builder: Builder, - local_vars: List[ContextVariable], - global_vars: List[ContextVariable], + local_vars: Dict[str, Any], + global_vars: Dict[str, Any], ) -> Any: """ The customize hook is called every time a snapshot value should be converted into code. @@ -46,11 +45,9 @@ def customize( This is the actual runtime value from your test. builder: A Builder instance providing methods to construct custom code representations. Use methods like `create_call()`, `create_dict()`, `create_external()`, etc. - local_vars: List of local variables available in the current scope, each containing - `name` and `value` attributes. Useful for referencing existing variables - instead of creating new literals. - global_vars: List of global variables available in the current scope, each containing - `name` and `value` attributes. + local_vars: Dictionary mapping variable names to their values in the local scope. + Useful for referencing existing variables instead of creating new literals. + global_vars: Dictionary mapping variable names to their values in the global scope. Returns: (Custom): created using [Builder][inline_snapshot.plugin.Builder] `create_*` methods. From 2fd6c20e1bcd87687e351bbb889245e130234adc Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 14 Jan 2026 22:16:50 +0100 Subject: [PATCH 50/72] feat: typing, custom __file__ --- src/inline_snapshot/_customize/_builder.py | 11 ++++---- src/inline_snapshot/plugin/_default_plugin.py | 28 +++++++++++-------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 48f7724a..bfa1a004 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any +from typing import Callable from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._compare_context import compare_context @@ -71,7 +72,7 @@ def create_external( return CustomExternal(value, format=format, storage=storage) - def create_list(self, value) -> Custom: + def create_list(self, value: list) -> Custom: """ Creates an intermediate node for a list-expression which can be used as a result for your customization function. @@ -81,7 +82,7 @@ def create_list(self, value) -> Custom: custom = [self._get_handler(v) for v in value] return CustomList(value=custom) - def create_tuple(self, value) -> Custom: + def create_tuple(self, value: tuple) -> Custom: """ Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. @@ -106,7 +107,7 @@ def with_default(self, value: Any, default: Any): return value def create_call( - self, function, posonly_args=[], kwargs={}, kwonly_args={} + self, function: Custom | Callable, posonly_args=[], kwargs={}, kwonly_args={} ) -> Custom: """ Creates an intermediate node for a function call expression which can be used as a result for your customization function. @@ -126,7 +127,7 @@ def create_call( _kwonly=kwonly_args, ) - def create_dict(self, value) -> Custom: + def create_dict(self, value: dict) -> Custom: """ Creates an intermediate node for a dict-expression which can be used as a result for your customization function. @@ -136,7 +137,7 @@ def create_dict(self, value) -> Custom: custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def create_code(self, value, repr: str | None = None) -> CustomCode: + def create_code(self, value: Any, repr: str | None = None) -> CustomCode: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 52153591..10848b55 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -10,6 +10,8 @@ from pathlib import PurePath from types import BuiltinFunctionType from types import FunctionType +from typing import Any +from typing import Dict from inline_snapshot._customize._builder import Builder from inline_snapshot._customize._custom_undefined import CustomUndefined @@ -53,9 +55,15 @@ def counter_handler(self, value, builder: Builder): return builder.create_call(Counter, [dict(value)]) @customize - def function_handler(self, value, builder: Builder): - if isinstance(value, FunctionType): - qualname = value.__qualname__ + def function_and_type_handler( + self, value, builder: Builder, local_vars: Dict[str, Any] + ): + if isinstance(value, (FunctionType, type)): + for name, local_value in local_vars.items(): + if local_value is value: + return builder.create_code(value, name) + + qualname = value.__qualname__.split("[")[0] name = qualname.split(".")[0] return builder.create_code(value, qualname).with_import( value.__module__, name @@ -66,15 +74,6 @@ def builtin_function_handler(self, value, builder: Builder): if isinstance(value, BuiltinFunctionType): return builder.create_code(value, value.__name__) - @customize - def type_handler(self, value, builder: Builder): - if isinstance(value, type): - qualname = value.__qualname__.split("[")[0] - name = qualname.split(".")[0] - return builder.create_code(value, qualname).with_import( - value.__module__, name - ) - @customize def path_handler(self, value, builder: Builder): if isinstance(value, Path): @@ -141,6 +140,11 @@ def flag_handler(self, value, builder: Builder): ), ).with_import(type(value).__module__, name) + @customize + def source_file_name_handler(self, value, builder: Builder, global_vars): + if value == global_vars["__file__"]: + return builder.create_code(value, "__file__") + @customize def dataclass_handler(self, value, builder: Builder): From 6f7b7e6ae1b364466970a49c767efb69d9b41dd0 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 14 Jan 2026 22:29:32 +0100 Subject: [PATCH 51/72] refactor: moved _custom_value.py --- src/inline_snapshot/_customize/_builder.py | 2 +- .../_customize/{_custom_value.py => _custom_code.py} | 0 src/inline_snapshot/_new_adapter.py | 2 +- src/inline_snapshot/_snapshot/undecided_value.py | 2 +- src/inline_snapshot/plugin/__init__.py | 2 +- src/inline_snapshot/plugin/_default_plugin.py | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename src/inline_snapshot/_customize/{_custom_value.py => _custom_code.py} (100%) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index bfa1a004..04a1bf01 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -11,11 +11,11 @@ from ._custom import Custom from ._custom_call import CustomCall from ._custom_call import CustomDefault +from ._custom_code import CustomCode from ._custom_dict import CustomDict from ._custom_external import CustomExternal from ._custom_sequence import CustomList from ._custom_sequence import CustomTuple -from ._custom_value import CustomCode @dataclass diff --git a/src/inline_snapshot/_customize/_custom_value.py b/src/inline_snapshot/_customize/_custom_code.py similarity index 100% rename from src/inline_snapshot/_customize/_custom_value.py rename to src/inline_snapshot/_customize/_custom_code.py diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index ae6951a3..9b603338 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -20,12 +20,12 @@ from inline_snapshot._customize._custom import Custom from inline_snapshot._customize._custom_call import CustomCall from inline_snapshot._customize._custom_call import CustomDefault +from inline_snapshot._customize._custom_code import CustomCode from inline_snapshot._customize._custom_dict import CustomDict from inline_snapshot._customize._custom_sequence import CustomList from inline_snapshot._customize._custom_sequence import CustomSequence from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged -from inline_snapshot._customize._custom_value import CustomCode from inline_snapshot._exceptions import UsageError from inline_snapshot._generator_utils import only_value from inline_snapshot.syntax_warnings import InlineSnapshotInfo diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 388588e0..b9bd1055 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -5,12 +5,12 @@ from inline_snapshot._compare_context import compare_only from inline_snapshot._customize._custom import Custom from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_code import CustomCode from inline_snapshot._customize._custom_dict import CustomDict from inline_snapshot._customize._custom_sequence import CustomList from inline_snapshot._customize._custom_sequence import CustomTuple from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged -from inline_snapshot._customize._custom_value import CustomCode from inline_snapshot._new_adapter import NewAdapter from inline_snapshot._new_adapter import warn_star_expression from inline_snapshot._unmanaged import is_unmanaged diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py index 4dae7995..091e33a5 100644 --- a/src/inline_snapshot/plugin/__init__.py +++ b/src/inline_snapshot/plugin/__init__.py @@ -1,4 +1,4 @@ -from inline_snapshot._customize._custom_value import CustomCode +from inline_snapshot._customize._custom_code import CustomCode from .._customize._builder import Builder from .._customize._custom import Custom diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 10848b55..bf008a33 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -142,7 +142,7 @@ def flag_handler(self, value, builder: Builder): @customize def source_file_name_handler(self, value, builder: Builder, global_vars): - if value == global_vars["__file__"]: + if "__file__" in global_vars and value == global_vars["__file__"]: return builder.create_code(value, "__file__") @customize From 7a460271dbde69075b18629f5c68a7a371eb29a9 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 17 Jan 2026 08:17:53 +0100 Subject: [PATCH 52/72] feat: with_import('module') --- docs/plugin.md | 6 +- src/inline_snapshot/_change.py | 13 +- src/inline_snapshot/_customize/_builder.py | 2 +- .../_customize/_custom_code.py | 27 +++- .../_external/_find_external.py | 26 +++- src/inline_snapshot/_inline_snapshot.py | 14 +- src/inline_snapshot/_new_adapter.py | 18 ++- src/inline_snapshot/plugin/_default_plugin.py | 8 +- tests/external/test_external.py | 6 +- tests/test_customize.py | 127 ++++++++++++++++++ 10 files changed, 219 insertions(+), 28 deletions(-) diff --git a/docs/plugin.md b/docs/plugin.md index 7db2cbc0..da9b89f1 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -322,9 +322,9 @@ class InlineSnapshotPlugin: def secret_handler(self, value, builder: Builder): for i, secret in enumerate(secrets): if value == secret: - return builder.create_code(secret, f"secrets[{i}]").with_import( - "my_secrets", "secrets" - ) + return builder.create_code( + secret, f"secrets[{i}]" + ).with_import_from("my_secrets", "secrets") ``` The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the actual value and its desired code representation, then [`with_import()`][inline_snapshot.plugin.CustomCode.with_import] adds the necessary import statement. diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 1edcaa90..fcd5317b 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import dataclasses from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -119,6 +120,7 @@ def apply_external_changes(self): @dataclass() class RequiredImports(Change): imports: dict[str, set[str]] + module_imports: set[str] = dataclasses.field(default_factory=set) @dataclass() @@ -277,6 +279,7 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): # file -> module -> names imports_by_file: dict[str, dict[str, set]] = defaultdict(lambda: defaultdict(set)) + module_imports_by_file: dict[str, set] = defaultdict(set) for change in all_changes: if isinstance(change, Delete): @@ -292,12 +295,16 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): sources[node] = change.file elif isinstance(change, RequiredImports): for module, names in change.imports.items(): - imports_by_file[change.filename][module] |= set(names) + imports_by_file[change.file.filename][module] |= set(names) + for module in change.module_imports: + module_imports_by_file[change.file.filename].add(module) else: change.apply(recorder) - for filename, imports in imports_by_file.items(): - ensure_import(filename, imports, recorder) + for filename in set(imports_by_file) | set(module_imports_by_file): + imports = imports_by_file.get(filename, defaultdict(set)) + module_imports = module_imports_by_file.get(filename, set()) + ensure_import(filename, imports, module_imports, recorder) for parent, changes in by_parent.items(): source = sources[parent] diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 04a1bf01..f9e022cc 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -144,6 +144,6 @@ def create_code(self, value: Any, repr: str | None = None) -> CustomCode: `create_code(value, '{value-1!r}+1')` becomes `4+1` in the code for a given `value=5`. Use this when you need to control the exact string representation of a value. - You can use [`.with_import(module,name)`][inline_snapshot.plugin.CustomCode.with_import] to create an import in the code. + You can use [`.with_import_from(module,name)`][inline_snapshot.plugin.CustomCode.with_import_from] to create an import in the code. """ return CustomCode(value, repr) diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index 544f536d..75b1cc78 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -40,7 +40,7 @@ def __init__(self, value, repr_str=None): ast.parse(self.repr_str) except SyntaxError: self.repr_str = HasRepr(type(value), self.repr_str).__repr__() - self.with_import("inline_snapshot", "HasRepr") + self.with_import_from("inline_snapshot", "HasRepr") else: self.repr_str = repr_str @@ -61,7 +61,7 @@ def __repr__(self): def _needed_imports(self): yield from self._imports - def with_import(self, module: str, name: str, simplify: bool = True) -> Self: + def with_import_from(self, module: str, name: str, simplify: bool = True) -> Self: """ Adds a `from module import name` statement to the generated code. @@ -77,7 +77,9 @@ def with_import(self, module: str, name: str, simplify: bool = True) -> Self: Example: ``` python - builder.create_value(my_obj, "secrets[0]").with_import("my_secrets", "secrets") + builder.create_code(my_obj, "secrets[0]").with_import_from( + "my_secrets", "secrets" + ) ``` """ name = name.split("[")[0] @@ -86,3 +88,22 @@ def with_import(self, module: str, name: str, simplify: bool = True) -> Self: self._imports.append([module, name]) return self + + def with_import(self, module: str) -> Self: + """ + Adds an `import module` statement to the generated code. + + Arguments: + module: The module path to import (e.g., "os.path" or "collections.abc"). + + Returns: + The CustomCode instance itself, allowing for method chaining. + + Example: + ``` python + builder.create_code(my_obj, "os.path.join('a', 'b')").with_import("os.path") + ``` + """ + self._imports.append((module,)) + + return self diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index af989f2a..280883e8 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -26,6 +26,14 @@ def contains_import(tree, module, name): return False +def contains_module_import(tree, module): + for node in tree.body: + if isinstance(node, ast.Import): + if any(alias.name == module for alias in node.names): + return True + return False + + def used_externals_in( filename: Path, source: Union[str, ast.Module], check_import=True ) -> List[ExternalLocation]: @@ -98,7 +106,7 @@ def module_name_of(filename: str | os.PathLike) -> Optional[str]: return ".".join(parts) -def ensure_import(filename, imports, recorder: ChangeRecorder): +def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): source = Source.for_filename(filename) change = recorder.new_change() @@ -107,6 +115,7 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): token = source.asttokens() to_add = [] + modules_to_add = [] my_module = module_name_of(filename) @@ -119,6 +128,14 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): if not contains_import(tree, module, name): to_add.append((module, name)) + for module in sorted(module_imports): + if module == my_module: + continue + if module == "builtins": + continue + if not contains_module_import(tree, module): + modules_to_add.append(module) + assert isinstance(tree, ast.Module) last_import = None @@ -147,7 +164,12 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): position = end_of(last_token) code = "" + for module in modules_to_add: + code += f"import {module}\n" for module, name in to_add: - code += f"\nfrom {module} import {name}\n" + code += f"from {module} import {name}\n" + + if code: + code = "\n" + code change.insert(position, code, filename=filename) diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 3e9a61a2..7c6fc7aa 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -152,11 +152,19 @@ def _changes(self) -> Iterator[ChangeBase]: ) imports: dict[str, set[str]] = defaultdict(set) - for module, name in self._value._needed_imports(): - imports[module].add(name) + module_imports: set[str] = set() + for import_info in self._value._needed_imports(): + if len(import_info) == 2: + module, name = import_info + imports[module].add(name) + elif len(import_info) == 1: + module_imports.add(import_info[0]) yield RequiredImports( - flag="create", file=self._value._file, imports=imports + flag="create", + file=self._value._file, + imports=imports, + module_imports=module_imports, ) else: diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 9b603338..d28d1b56 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -224,12 +224,18 @@ def compare_CustomCode( def needed_imports(value: Custom): imports: dict[str, set] = defaultdict(set) - for module, name in value._needed_imports(): - imports[module].add(name) - return imports - - if imports := needed_imports(new_value): - yield RequiredImports(flag, self.context.file, imports) + module_imports: set[str] = set() + for import_info in value._needed_imports(): + if len(import_info) == 2: + module, name = import_info + imports[module].add(name) + elif len(import_info) == 1: + module_imports.add(import_info[0]) + return imports, module_imports + + imports, module_imports = needed_imports(new_value) + if imports or module_imports: + yield RequiredImports(flag, self.context.file, imports, module_imports) return new_value diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index bf008a33..960db91d 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -65,7 +65,7 @@ def function_and_type_handler( qualname = value.__qualname__.split("[")[0] name = qualname.split(".")[0] - return builder.create_code(value, qualname).with_import( + return builder.create_code(value, qualname).with_import_from( value.__module__, name ) @@ -123,7 +123,7 @@ def enum_handler(self, value, builder: Builder): return builder.create_code( value, f"{type(value).__qualname__}.{value.name}" - ).with_import(type(value).__module__, name) + ).with_import_from(type(value).__module__, name) # -8<- [end:Enum] @@ -138,7 +138,7 @@ def flag_handler(self, value, builder: Builder): " | ".join( f"{qualname}.{flag.name}" for flag in type(value) if flag in value ), - ).with_import(type(value).__module__, name) + ).with_import_from(type(value).__module__, name) @customize def source_file_name_handler(self, value, builder: Builder, global_vars): @@ -219,7 +219,7 @@ def dirty_equals_handler(self, value, builder: Builder): if is_dirty_equal(value) and builder._build_new_value: if isinstance(value, type): - return builder.create_code(value, value.__name__).with_import( + return builder.create_code(value, value.__name__).with_import_from( "dirty_equals", value.__name__ ) else: diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 1158d860..1a1302b4 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -495,7 +495,7 @@ def test_ensure_imports(tmp_path): ) with apply_changes() as recorder: - ensure_import(file, {"os": ["chdir", "environ"]}, recorder) + ensure_import(file, {"os": ["chdir", "environ"]}, set(), recorder) assert file.read_text("utf-8") == snapshot( """\ @@ -516,7 +516,7 @@ def test_ensure_imports_with_comment(tmp_path): ) with apply_changes() as recorder: - ensure_import(file, {"os": ["chdir"]}, recorder) + ensure_import(file, {"os": ["chdir"]}, set(), recorder) assert file.read_text("utf-8") == snapshot( """\ @@ -537,7 +537,7 @@ def test_ensure_imports_with_docstring(tmp_path): ) with apply_changes() as recorder: - ensure_import(file, {"os": ["chdir"]}, recorder) + ensure_import(file, {"os": ["chdir"]}, set(), recorder) assert file.read_text("utf-8") == snapshot( """\ diff --git a/tests/test_customize.py b/tests/test_customize.py index 00e6b2cd..12c89a22 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -84,3 +84,130 @@ def test(): } ), ) + + +@pytest.mark.parametrize( + "original,flag", + [("ComplexObj(1, 2)", "update"), ("'wrong'", "fix"), ("", "create")], +) +def test_with_import(original, flag): + """Test that with_import adds both simple and nested module import statements correctly.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder +from pkg.subpkg import ComplexObj + +class InlineSnapshotPlugin: + @customize + def complex_handler(self, value, builder: Builder): + if isinstance(value, ComplexObj): + return builder.create_code( + value, + f"mod1.helper(pkg.subpkg.create({value.a!r}, {value.b!r}))" + ).with_import("mod1").with_import("pkg.subpkg") +""", + "mod1.py": """\ +def helper(obj): + return obj +""", + "pkg/__init__.py": "", + "pkg/subpkg.py": """\ +class ComplexObj: + def __init__(self, a, b): + self.a = a + self.b = b + + def __eq__(self, other): + return isinstance(other, ComplexObj) and self.a == other.a and self.b == other.b + +def create(a, b): + return ComplexObj(a, b) +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot +from pkg.subpkg import ComplexObj + +def test_a(): + assert snapshot({original}) == ComplexObj(1, 2) +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from pkg.subpkg import ComplexObj + +import mod1 +import pkg.subpkg + +def test_a(): + assert snapshot(mod1.helper(pkg.subpkg.create(1, 2))) == ComplexObj(1, 2) +""" + } + ), + ).run_inline() + + +@pytest.mark.parametrize( + "original,flag", [("MyClass('value')", "update"), ("'wrong'", "fix")] +) +def test_with_import_preserves_existing(original, flag): + """Test that with_import preserves existing import statements.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder +from mymodule import MyClass + +class InlineSnapshotPlugin: + @customize + def myclass_handler(self, value, builder: Builder): + if isinstance(value, MyClass): + return builder.create_code( + value, + f"mymodule.MyClass({value.value!r})" + ).with_import("mymodule") +""", + "mymodule.py": """\ +class MyClass: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, MyClass) and self.value == other.value +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot +from mymodule import MyClass + +import mymodule +import os + +def test_a(): + assert snapshot({original}) == MyClass("value") +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from mymodule import MyClass + +import mymodule +import os + +def test_a(): + assert snapshot(mymodule.MyClass("value")) == MyClass("value") +""" + } + ), + ).run_inline() From 5754bd3452b4924e3b5ea09da2f4f709067b40f3 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 17 Jan 2026 21:56:06 +0100 Subject: [PATCH 53/72] refactor: removed _needed_imports --- docs/external/outsource.md | 2 +- src/inline_snapshot/_change.py | 2 + src/inline_snapshot/_customize/_custom.py | 3 -- .../_customize/_custom_call.py | 14 ------ .../_customize/_custom_code.py | 22 +++++++-- .../_customize/_custom_dict.py | 5 -- .../_customize/_custom_external.py | 7 +-- .../_customize/_custom_sequence.py | 4 -- src/inline_snapshot/_generator_utils.py | 25 ++++++++++ src/inline_snapshot/_inline_snapshot.py | 21 +-------- src/inline_snapshot/_new_adapter.py | 47 +++++-------------- .../_snapshot/generic_value.py | 3 -- tests/external/test_external.py | 12 ++--- 13 files changed, 70 insertions(+), 97 deletions(-) diff --git a/docs/external/outsource.md b/docs/external/outsource.md index 605528f6..4c50323f 100644 --- a/docs/external/outsource.md +++ b/docs/external/outsource.md @@ -51,7 +51,7 @@ def test_captcha(): "size": "200x100", "difficulty": 8, "picture": external( - "uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.png" + "uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.png" ), } ) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index fcd5317b..e63421fc 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -119,6 +119,8 @@ def apply_external_changes(self): @dataclass() class RequiredImports(Change): + # module:str + # name:Optional[str] imports: dict[str, set[str]] module_imports: set[str] = dataclasses.field(default_factory=set) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py index bb6d8a2f..fe6a3575 100644 --- a/src/inline_snapshot/_customize/_custom.py +++ b/src/inline_snapshot/_customize/_custom.py @@ -40,6 +40,3 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str def _eval(self): return self._map(lambda a: a) - - def _needed_imports(self): - yield from () diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py index 91a8fdf9..528b43d6 100644 --- a/src/inline_snapshot/_customize/_custom_call.py +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -23,9 +23,6 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str def _map(self, f): return self.value._map(f) - def _needed_imports(self): - yield from self.value._needed_imports() - def unwrap_default(value): if isinstance(value, CustomDefault): @@ -77,14 +74,3 @@ def _map(self, f): *[f(x._map(f)) for x in self._args], **{k: f(v._map(f)) for k, v in self.kwargs.items()}, ) - - def _needed_imports(self): - yield from self._function._needed_imports() - for v in self._args: - yield from v._needed_imports() - - for v in self._kwargs.values(): - yield from v._needed_imports() - - for v in self._kwonly.values(): - yield from v._needed_imports() diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index 75b1cc78..360cd5d3 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -8,6 +8,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase +from inline_snapshot._change import RequiredImports from inline_snapshot._code_repr import HasRepr from inline_snapshot._code_repr import value_code_repr from inline_snapshot._utils import clone @@ -28,10 +29,14 @@ def _simplify_module_path(module: str, name: str) -> str: class CustomCode(Custom): + _imports: list[tuple[str, str]] + _module_imports: list[str] + def __init__(self, value, repr_str=None): assert not isinstance(value, Custom) value = clone(value) self._imports = [] + self._module_imports = [] if repr_str is None: self.repr_str = value_code_repr(value) @@ -52,15 +57,22 @@ def _map(self, f): return f(self.value) def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () + file = context.file if context else None + + for module in self._module_imports: + yield RequiredImports( + flag="fix", file=file, imports={}, module_imports=[module] + ) + for module, name in self._imports: + yield RequiredImports( + flag="fix", file=file, imports={module: {name}}, module_imports=[] + ) + return self.repr_str def __repr__(self): return f"CustomValue({self.repr_str})" - def _needed_imports(self): - yield from self._imports - def with_import_from(self, module: str, name: str, simplify: bool = True) -> Self: """ Adds a `from module import name` statement to the generated code. @@ -104,6 +116,6 @@ def with_import(self, module: str) -> Self: builder.create_code(my_obj, "os.path.join('a', 'b')").with_import("os.path") ``` """ - self._imports.append((module,)) + self._module_imports.append(module) return self diff --git a/src/inline_snapshot/_customize/_custom_dict.py b/src/inline_snapshot/_customize/_custom_dict.py index f0f28874..43274b10 100644 --- a/src/inline_snapshot/_customize/_custom_dict.py +++ b/src/inline_snapshot/_customize/_custom_dict.py @@ -27,8 +27,3 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str values.append(f"{key}: {value}") return f"{{{ ', '.join(values)}}}" - - def _needed_imports(self): - for k, v in self.value.items(): - yield from k._needed_imports() - yield from v._needed_imports() diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index ceed5906..88c35013 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -8,6 +8,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase from inline_snapshot._change import ExternalChange +from inline_snapshot._change import RequiredImports from inline_snapshot._external._external_location import ExternalLocation from inline_snapshot._external._format._protocol import get_format_handler @@ -52,8 +53,8 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str location, format, ) + yield RequiredImports( + "create", context.file, {"inline_snapshot": ["external"]}, [] + ) return f"external({location.to_str()!r})" - - def _needed_imports(self): - return [("inline_snapshot", "external")] diff --git a/src/inline_snapshot/_customize/_custom_sequence.py b/src/inline_snapshot/_customize/_custom_sequence.py index 06389648..98252ae2 100644 --- a/src/inline_snapshot/_customize/_custom_sequence.py +++ b/src/inline_snapshot/_customize/_custom_sequence.py @@ -33,10 +33,6 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str trailing_comma = self.trailing_comma and len(self.value) == 1 return f"{self.braces[0]}{', '.join(values)}{', ' if trailing_comma else ''}{self.braces[1]}" - def _needed_imports(self): - for v in self.value: - yield from v._needed_imports() - class CustomList(CustomSequence): node_type = ast.List diff --git a/src/inline_snapshot/_generator_utils.py b/src/inline_snapshot/_generator_utils.py index bc094d88..5119fc82 100644 --- a/src/inline_snapshot/_generator_utils.py +++ b/src/inline_snapshot/_generator_utils.py @@ -15,3 +15,28 @@ def split_gen(gen): def only_value(gen): return split_gen(gen).value + + +def gen_map(stream, f): + stream = iter(stream) + while True: + try: + yield f(next(stream)) + except StopIteration as stop: + return stop.value + + +def with_flag(stream, flag): + def map(change): + change.flag = flag + return change + + return gen_map(stream, map) + + +def make_gen_map(f): + def m(stream): + + return gen_map(stream, f) + + return m diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 7c6fc7aa..441903fb 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,6 +1,5 @@ import ast import inspect -from collections import defaultdict from typing import Any from typing import Iterator from typing import TypeVar @@ -11,12 +10,12 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._adapter_context import FrameContext from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._generator_utils import with_flag from inline_snapshot._source_file import SourceFile from inline_snapshot._types import SnapshotRefBase from ._change import CallArg from ._change import ChangeBase -from ._change import RequiredImports from ._global_state import state from ._sentinels import undefined from ._snapshot.undecided_value import UndecidedValue @@ -139,7 +138,7 @@ def _changes(self) -> Iterator[ChangeBase]: if isinstance(self._value._new_value, CustomUndefined): return - new_code = yield from self._value._new_code() + new_code = yield from with_flag(self._value._new_code(), "create") yield CallArg( flag="create", @@ -151,22 +150,6 @@ def _changes(self) -> Iterator[ChangeBase]: new_value=self._value._new_value, ) - imports: dict[str, set[str]] = defaultdict(set) - module_imports: set[str] = set() - for import_info in self._value._needed_imports(): - if len(import_info) == 2: - module, name = import_info - imports[module].add(name) - elif len(import_info) == 1: - module_imports.add(import_info[0]) - - yield RequiredImports( - flag="create", - file=self._value._file, - imports=imports, - module_imports=module_imports, - ) - else: yield from self._value._get_changes() diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index d28d1b56..db5a2c6d 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -15,7 +15,6 @@ from inline_snapshot._change import DictInsert from inline_snapshot._change import ListInsert from inline_snapshot._change import Replace -from inline_snapshot._change import RequiredImports from inline_snapshot._compare_context import compare_context from inline_snapshot._customize._custom import Custom from inline_snapshot._customize._custom_call import CustomCall @@ -27,7 +26,9 @@ from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged from inline_snapshot._exceptions import UsageError +from inline_snapshot._generator_utils import make_gen_map from inline_snapshot._generator_utils import only_value +from inline_snapshot._generator_utils import split_gen from inline_snapshot.syntax_warnings import InlineSnapshotInfo from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning @@ -179,10 +180,7 @@ def compare_CustomCode( assert isinstance(new_value, Custom) assert isinstance(old_node, (ast.expr, type(None))), old_node - if old_node is None: - new_code = "" - else: - new_code = yield from new_value._code_repr(self.context) + new_code, new_changes = split_gen(new_value._code_repr(self.context)) if ( isinstance(old_node, ast.JoinedStr) @@ -213,6 +211,10 @@ def compare_CustomCode( # equal and equal repr return old_value + for change in new_changes: + change.flag = flag + yield change + yield Replace( node=old_node, file=self.context.file, @@ -222,21 +224,6 @@ def compare_CustomCode( new_value=new_value, ) - def needed_imports(value: Custom): - imports: dict[str, set] = defaultdict(set) - module_imports: set[str] = set() - for import_info in value._needed_imports(): - if len(import_info) == 2: - module, name = import_info - imports[module].add(name) - elif len(import_info) == 1: - module_imports.add(import_info[0]) - return imports, module_imports - - imports, module_imports = needed_imports(new_value) - if imports or module_imports: - yield RequiredImports(flag, self.context.file, imports, module_imports) - return new_value def compare_CustomSequence( @@ -408,20 +395,12 @@ def compare_CustomCall( flag = "update" if old_value._eval() == new_value.original_value else "fix" - if flag == "update": - - def intercept(stream): - while True: - try: - change = next(stream) - if change.flag == "fix": - change.flag = "update" - yield change - except StopIteration as stop: - return stop.value - - else: - intercept = lambda a: a + @make_gen_map + def intercept(change): + if flag == "update": + if change.flag == "fix": + change.flag = "update" + return change old_node_args: Sequence[ast.expr | None] if old_node: diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index efa3de5a..01f17ba5 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -77,9 +77,6 @@ def _ignore_old(self): or isinstance(self._old_value, CustomUndefined) ) - def _needed_imports(self): - yield from self._new_value._needed_imports() - def _visible_value(self): if self._ignore_old(): return self._new_value diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 1a1302b4..f197d986 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -19,7 +19,7 @@ def test_basic(check_update): assert check_update( "assert outsource('text') == snapshot()", flags="create" ) == snapshot( - "assert outsource('text') == snapshot(external(\"uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt\"))" + "assert outsource('text') == snapshot(external(\"uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt\"))" ) @@ -272,7 +272,7 @@ def test_a(): from inline_snapshot import external def test_a(): - assert outsource("test") == snapshot(external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt")) + assert outsource("test") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) """ ) @@ -478,7 +478,7 @@ def test_something(): from inline_snapshot import external def test_something(): from inline_snapshot import outsource,snapshot - assert outsource("test") == snapshot(external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt")) + assert outsource("test") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) test_something() \ """ @@ -626,15 +626,15 @@ def test_something(): ["--inline-snapshot=create"], changed_files=snapshot( { + "tests/__inline_snapshot__/test_something/test_something/eb1167b3-67a9-4378-bc65-c1e582e2e662.txt": "foo", "tests/__inline_snapshot__/test_something/test_something/f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt": "foo", - "tests/__inline_snapshot__/test_something/test_something/e3e70682-c209-4cac-a29f-6fbed82c07cd.txt": "foo", "tests/test_something.py": """\ from inline_snapshot import external, snapshot,outsource def test_something(): - assert outsource("foo") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) - assert "foo" == external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt") + assert outsource("foo") == snapshot(external("uuid:eb1167b3-67a9-4378-bc65-c1e582e2e662.txt")) + assert "foo" == external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt") """, } ), From 545af71ab5a8ec42b44d1ee8a024260cc4e49c92 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 17 Jan 2026 22:12:52 +0100 Subject: [PATCH 54/72] refactor: simplified RequiredImport --- src/inline_snapshot/_change.py | 19 ++++++------- .../_customize/_custom_code.py | 10 +++---- .../_customize/_custom_external.py | 6 ++--- .../_external/_find_external.py | 27 +++++++++---------- 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index e63421fc..bf9d4c52 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import dataclasses from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -118,11 +117,9 @@ def apply_external_changes(self): @dataclass() -class RequiredImports(Change): - # module:str - # name:Optional[str] - imports: dict[str, set[str]] - module_imports: set[str] = dataclasses.field(default_factory=set) +class RequiredImport(Change): + module: str + name: str | None = None @dataclass() @@ -295,11 +292,11 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): node = cast(EnhancedAST, change.node) by_parent[node].append(change) sources[node] = change.file - elif isinstance(change, RequiredImports): - for module, names in change.imports.items(): - imports_by_file[change.file.filename][module] |= set(names) - for module in change.module_imports: - module_imports_by_file[change.file.filename].add(module) + elif isinstance(change, RequiredImport): + if change.name: + imports_by_file[change.file.filename][change.module].add(change.name) + else: + module_imports_by_file[change.file.filename].add(change.module) else: change.apply(recorder) diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index 360cd5d3..80874e49 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -8,7 +8,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase -from inline_snapshot._change import RequiredImports +from inline_snapshot._change import RequiredImport from inline_snapshot._code_repr import HasRepr from inline_snapshot._code_repr import value_code_repr from inline_snapshot._utils import clone @@ -60,13 +60,9 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str file = context.file if context else None for module in self._module_imports: - yield RequiredImports( - flag="fix", file=file, imports={}, module_imports=[module] - ) + yield RequiredImport(flag="fix", file=file, module=module) for module, name in self._imports: - yield RequiredImports( - flag="fix", file=file, imports={module: {name}}, module_imports=[] - ) + yield RequiredImport(flag="fix", file=file, module=module, name=name) return self.repr_str diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py index 88c35013..069fdcea 100644 --- a/src/inline_snapshot/_customize/_custom_external.py +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -8,7 +8,7 @@ from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase from inline_snapshot._change import ExternalChange -from inline_snapshot._change import RequiredImports +from inline_snapshot._change import RequiredImport from inline_snapshot._external._external_location import ExternalLocation from inline_snapshot._external._format._protocol import get_format_handler @@ -53,8 +53,6 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str location, format, ) - yield RequiredImports( - "create", context.file, {"inline_snapshot": ["external"]}, [] - ) + yield RequiredImport("create", context.file, "inline_snapshot", "external") return f"external({location.to_str()!r})" diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index 280883e8..cadae7ac 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -2,8 +2,10 @@ import os from dataclasses import replace from pathlib import Path +from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Union from executing import Source @@ -106,7 +108,12 @@ def module_name_of(filename: str | os.PathLike) -> Optional[str]: return ".".join(parts) -def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): +def ensure_import( + filename, + imports: Dict[str, Set[str]], + module_imports: Set[str], + recorder: ChangeRecorder, +): source = Source.for_filename(filename) change = recorder.new_change() @@ -114,11 +121,9 @@ def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): tree = source.tree token = source.asttokens() - to_add = [] - modules_to_add = [] - my_module = module_name_of(filename) + code = "" for module, names in imports.items(): if module == my_module: continue @@ -126,7 +131,7 @@ def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): continue for name in sorted(names): if not contains_import(tree, module, name): - to_add.append((module, name)) + code += f"from {module} import {name}\n" for module in sorted(module_imports): if module == my_module: @@ -134,10 +139,11 @@ def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): if module == "builtins": continue if not contains_module_import(tree, module): - modules_to_add.append(module) + code += f"import {module}\n" assert isinstance(tree, ast.Module) + # find source position last_import = None for node in tree.body: if not ( @@ -163,13 +169,6 @@ def ensure_import(filename, imports, module_imports, recorder: ChangeRecorder): break position = end_of(last_token) - code = "" - for module in modules_to_add: - code += f"import {module}\n" - for module, name in to_add: - code += f"from {module} import {name}\n" - if code: code = "\n" + code - - change.insert(position, code, filename=filename) + change.insert(position, code, filename=filename) From ba7726ac87e536cbce1d3032bc7dd4fc0571a28c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 18 Jan 2026 17:24:23 +0100 Subject: [PATCH 55/72] fix: typeing fixes --- src/inline_snapshot/_adapter_context.py | 3 ++- src/inline_snapshot/_code_repr.py | 5 ++++- src/inline_snapshot/_customize/_custom_code.py | 10 +++++----- src/inline_snapshot/_external/_find_external.py | 2 +- src/inline_snapshot/_snapshot/collection_value.py | 3 ++- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/inline_snapshot/_adapter_context.py b/src/inline_snapshot/_adapter_context.py index 9de8cc55..e95cf3c3 100644 --- a/src/inline_snapshot/_adapter_context.py +++ b/src/inline_snapshot/_adapter_context.py @@ -1,5 +1,6 @@ import ast from dataclasses import dataclass +from typing import Optional from inline_snapshot._source_file import SourceFile @@ -13,7 +14,7 @@ class FrameContext: @dataclass class AdapterContext: file: SourceFile - frame: FrameContext | None + frame: Optional[FrameContext] qualname: str def eval(self, node): diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 02c52104..09df8105 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -78,7 +78,10 @@ def _(obj: MyCustomClass): def code_repr(obj): - with mock_repr(None): + from inline_snapshot._adapter_context import AdapterContext + + context = AdapterContext(None, None, "") + with mock_repr(context): return repr(obj) diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index 80874e49..dda78fc8 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -57,12 +57,12 @@ def _map(self, f): return f(self.value) def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - file = context.file if context else None - for module in self._module_imports: - yield RequiredImport(flag="fix", file=file, module=module) + yield RequiredImport(flag="fix", file=context.file, module=module) for module, name in self._imports: - yield RequiredImport(flag="fix", file=file, module=module, name=name) + yield RequiredImport( + flag="fix", file=context.file, module=module, name=name + ) return self.repr_str @@ -93,7 +93,7 @@ def with_import_from(self, module: str, name: str, simplify: bool = True) -> Sel name = name.split("[")[0] if simplify: module = _simplify_module_path(module, name) - self._imports.append([module, name]) + self._imports.append((module, name)) return self diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index cadae7ac..e3f928c5 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -76,7 +76,7 @@ def used_externals_in( return usages -def module_name_of(filename: str | os.PathLike) -> Optional[str]: +def module_name_of(filename: Union[str, os.PathLike]) -> Optional[str]: path = Path(filename).resolve() if path.suffix != ".py": diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 681fe7dd..360c452f 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -1,6 +1,7 @@ import ast from typing import Generator from typing import Iterator +from typing import Union from inline_snapshot._customize._custom_sequence import CustomList from inline_snapshot._customize._custom_undefined import CustomUndefined @@ -16,7 +17,7 @@ class CollectionValue(GenericValue): _current_op = "x in snapshot" - _ast_node: ast.List | ast.Tuple + _ast_node: Union[ast.List, ast.Tuple] _new_value: CustomList def __contains__(self, item): From ae3ebd075cbf47057a5f2eb35a742b7239b1e6e6 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 18 Jan 2026 22:08:48 +0100 Subject: [PATCH 56/72] test: fixed docs --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1cf625e2..35739be6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dev = [ "coverage-enable-subprocess>=1.0", "attrs>=24.3.0", "pydantic>=1", - "black==25.1.0" + "black==25.1.0", "isort" ] From d1092a0d664824aa4fccbfab093613904f202e9e Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 19 Jan 2026 21:32:55 +0100 Subject: [PATCH 57/72] test: coverage --- src/inline_snapshot/_compare_context.py | 6 ++-- .../_customize/_custom_unmanaged.py | 6 ++-- src/inline_snapshot/_global_state.py | 4 +-- src/inline_snapshot/_utils.py | 8 ----- tests/external/test_external.py | 5 +-- tests/test_builder.py | 34 +++++++++++++++++++ 6 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 tests/test_builder.py diff --git a/src/inline_snapshot/_compare_context.py b/src/inline_snapshot/_compare_context.py index 104a235c..c70bcdcd 100644 --- a/src/inline_snapshot/_compare_context.py +++ b/src/inline_snapshot/_compare_context.py @@ -13,5 +13,7 @@ def compare_context(): global _eq_check_only old_eq_only = _eq_check_only _eq_check_only = True - yield - _eq_check_only = old_eq_only + try: + yield + finally: + _eq_check_only = old_eq_only diff --git a/src/inline_snapshot/_customize/_custom_unmanaged.py b/src/inline_snapshot/_customize/_custom_unmanaged.py index f85acb8b..62b63e54 100644 --- a/src/inline_snapshot/_customize/_custom_unmanaged.py +++ b/src/inline_snapshot/_customize/_custom_unmanaged.py @@ -14,8 +14,10 @@ class CustomUnmanaged(Custom): value: Any - def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - yield from () # pragma: no cover + def _code_repr( + self, context: AdapterContext + ) -> Generator[ChangeBase, None, str]: # pragma: no cover + yield from () return "'unmanaged'" def _map(self, f): diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index fe2c67ef..89053f3d 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -92,14 +92,14 @@ def enter_snapshot_context(): try: from .plugin._default_plugin import InlineSnapshotAttrsPlugin - except ImportError: + except ImportError: # pragma: no cover pass else: _current.pm.register(InlineSnapshotAttrsPlugin()) try: from .plugin._default_plugin import InlineSnapshotPydanticPlugin - except ImportError: + except ImportError: # pragma: no cover pass else: _current.pm.register(InlineSnapshotPydanticPlugin()) diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 48d4f1c2..0ba40271 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -164,11 +164,3 @@ def clone(obj): """ ) return new - - -def clone_if_equal(obj): - new = copy.deepcopy(obj) - if obj == new: - return new - else: - return obj diff --git a/tests/external/test_external.py b/tests/external/test_external.py index f197d986..6a78578a 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -72,8 +72,9 @@ def test_a(): def test_compare_outsource(): - assert outsource("one") == outsource("one") - assert outsource("one") != outsource("two") + with snapshot_env(): + assert outsource("one") == outsource("one") + assert outsource("one") != outsource("two") def test_hash_collision(): diff --git a/tests/test_builder.py b/tests/test_builder.py new file mode 100644 index 00000000..b2a5f7ba --- /dev/null +++ b/tests/test_builder.py @@ -0,0 +1,34 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_default_arg(): + + e = Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize + +class InlineSnapshotPlugin: + @customize + def handler(self,value,builder): + if value==5: + return builder.with_default(5,builder.create_code(8)) +""", + "test_a.py": """\ +from inline_snapshot import snapshot +def test_a(): + assert 5==snapshot() +""", + } + ) + + e.run_inline( + ["--inline-snapshot=create"], + raises=snapshot( + """\ +UsageError: +default value can not be an Custom value\ +""" + ), + ) From 38389097e4cbb00e8c2521225ff16972196a0779 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 20 Jan 2026 09:21:09 +0100 Subject: [PATCH 58/72] test: coverage --- pyproject.toml | 1 + .../_external/_find_external.py | 17 +++--- tests/external/test_module_name.py | 59 +++++++++++++++++++ 3 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 tests/external/test_module_name.py diff --git a/pyproject.toml b/pyproject.toml index 35739be6..6cbb6392 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ exclude_lines = [ "if is_insider", "\\.\\.\\." ] +ignore_errors=true [tool.coverage.run] diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index e3f928c5..95d43d5d 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -79,8 +79,7 @@ def used_externals_in( def module_name_of(filename: Union[str, os.PathLike]) -> Optional[str]: path = Path(filename).resolve() - if path.suffix != ".py": - return None + assert path.suffix == ".py" parts = [] @@ -97,13 +96,14 @@ def module_name_of(filename: Union[str, os.PathLike]) -> Optional[str]: next_parent = current.parent if next_parent == current: - break + break # pragma: no cover current = next_parent + else: + pass # pragma: no cover parts.reverse() - if not parts: - return None + assert parts return ".".join(parts) @@ -114,6 +114,7 @@ def ensure_import( module_imports: Set[str], recorder: ChangeRecorder, ): + print("file", filename) source = Source.for_filename(filename) change = recorder.new_change() @@ -121,11 +122,11 @@ def ensure_import( tree = source.tree token = source.asttokens() - my_module = module_name_of(filename) + my_module_name = module_name_of(filename) code = "" for module, names in imports.items(): - if module == my_module: + if module == my_module_name: continue if module == "builtins": continue @@ -134,7 +135,7 @@ def ensure_import( code += f"from {module} import {name}\n" for module in sorted(module_imports): - if module == my_module: + if module == my_module_name: continue if module == "builtins": continue diff --git a/tests/external/test_module_name.py b/tests/external/test_module_name.py new file mode 100644 index 00000000..78d84aee --- /dev/null +++ b/tests/external/test_module_name.py @@ -0,0 +1,59 @@ +"""Tests for module_name_of function coverage.""" + +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +def test_module_name_init_py_with_snapshot(): + """Test module_name_of when __init__.py itself contains the snapshot call.""" + Example( + { + "tests/__init__.py": "", + "tests/mypackage/b.py": """\ +from dataclasses import dataclass + +@dataclass +class B: + b:int +""", + "tests/mypackage/__init__.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +s = snapshot() +""", + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + from .mypackage import s,A + from .mypackage.b import B + assert s == [A(5),B(5)] +""", + } + ).run_pytest( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/mypackage/__init__.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +from tests.mypackage.b import B + +@dataclass +class A: + a:int + +s = snapshot([A(a=5), B(b=5)]) +""" + } + ), + returncode=1, + ).run_pytest( + ["--inline-snapshot=disable"] + ) From 29351b2262301eebb010bfd8c96ca00fbebe4b98 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 20 Jan 2026 13:08:23 +0100 Subject: [PATCH 59/72] test: coverage --- src/inline_snapshot/_external/_find_external.py | 8 +++----- tests/test_customize.py | 10 ++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index 95d43d5d..93011a27 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -31,7 +31,9 @@ def contains_import(tree, module, name): def contains_module_import(tree, module): for node in tree.body: if isinstance(node, ast.Import): - if any(alias.name == module for alias in node.names): + if any( + alias.name == module and alias.asname is None for alias in node.names + ): return True return False @@ -135,10 +137,6 @@ def ensure_import( code += f"from {module} import {name}\n" for module in sorted(module_imports): - if module == my_module_name: - continue - if module == "builtins": - continue if not contains_module_import(tree, module): code += f"import {module}\n" diff --git a/tests/test_customize.py b/tests/test_customize.py index 12c89a22..551142d7 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -156,7 +156,8 @@ def test_a(): @pytest.mark.parametrize( "original,flag", [("MyClass('value')", "update"), ("'wrong'", "fix")] ) -def test_with_import_preserves_existing(original, flag): +@pytest.mark.parametrize("existing_import", ["\nimport mymodule\n", ""]) +def test_with_import_preserves_existing(original, flag, existing_import): """Test that with_import preserves existing import statements.""" Example( @@ -187,8 +188,8 @@ def __eq__(self, other): from inline_snapshot import snapshot from mymodule import MyClass -import mymodule -import os +import os # just another import +{existing_import}\ def test_a(): assert snapshot({original}) == MyClass("value") @@ -202,8 +203,9 @@ def test_a(): from inline_snapshot import snapshot from mymodule import MyClass +import os # just another import + import mymodule -import os def test_a(): assert snapshot(mymodule.MyClass("value")) == MyClass("value") From b590ce54577c84e9b99090a9722c5a39fa74ea44 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 21 Jan 2026 07:58:13 +0100 Subject: [PATCH 60/72] refactor: removed value argument of builder.create_code(...) --- docs/customize_repr.md | 2 +- docs/plugin.md | 13 +- src/inline_snapshot/_customize/_builder.py | 125 ++++++++++++++---- .../_customize/_custom_code.py | 91 +++++-------- src/inline_snapshot/plugin/__init__.py | 6 +- src/inline_snapshot/plugin/_default_plugin.py | 33 ++--- src/inline_snapshot/plugin/_spec.py | 8 +- tests/conftest.py | 4 +- tests/test_builder.py | 2 +- tests/test_customize.py | 50 +++++-- tests/test_docs.py | 4 +- 11 files changed, 212 insertions(+), 126 deletions(-) diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 99c799b0..17dea008 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -8,7 +8,7 @@ @customize def my_class_handler(value, builder): if isinstance(value, MyClass): - return builder.create_code(value, "my_class_repr") + return builder.create_code("my_class_repr") ``` instead of diff --git a/docs/plugin.md b/docs/plugin.md index da9b89f1..e5f142e1 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -245,7 +245,7 @@ class InlineSnapshotPlugin: def local_var_handler(self, value, builder, local_vars): for var_name, var_value in local_vars.items(): if var_name.startswith("v_") and var_value == value: - return builder.create_code(value, var_name) + return builder.create_code(var_name) ``` We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local variable if we find one that fits the criteria. @@ -314,7 +314,7 @@ What you can do now, instead of replacing `"some_other_secret"` with `secrets[1] ``` python title="conftest.py" from my_secrets import secrets -from inline_snapshot.plugin import customize, Builder +from inline_snapshot.plugin import customize, Builder, ImportFrom class InlineSnapshotPlugin: @@ -323,11 +323,12 @@ class InlineSnapshotPlugin: for i, secret in enumerate(secrets): if value == secret: return builder.create_code( - secret, f"secrets[{i}]" - ).with_import_from("my_secrets", "secrets") + f"secrets[{i}]", + imports=[ImportFrom("my_secrets", "secrets")], + ) ``` -The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the actual value and its desired code representation, then [`with_import()`][inline_snapshot.plugin.CustomCode.with_import] adds the necessary import statement. +The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the desired code representation. The `imports` parameter adds the necessary import statements. Inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. @@ -365,7 +366,7 @@ def test_my_class(): ::: inline_snapshot.plugin options: heading_level: 3 - members: [hookimpl,customize,Builder,Custom,CustomCode] + members: [hookimpl,customize,Builder,Custom,Import,ImportFrom] show_root_heading: false show_bases: false show_source: false diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index f9e022cc..21a49cab 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -1,10 +1,14 @@ from __future__ import annotations from dataclasses import dataclass +from functools import cached_property from typing import Any from typing import Callable from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._sentinels import undefined + +missing = undefined from inline_snapshot._compare_context import compare_context from inline_snapshot._exceptions import UsageError @@ -12,6 +16,8 @@ from ._custom_call import CustomCall from ._custom_call import CustomDefault from ._custom_code import CustomCode +from ._custom_code import Import +from ._custom_code import ImportFrom from ._custom_dict import CustomDict from ._custom_external import CustomExternal from ._custom_sequence import CustomList @@ -27,24 +33,6 @@ def _get_handler(self, v) -> Custom: from inline_snapshot._global_state import state - if ( - self._snapshot_context is not None - and (frame := self._snapshot_context.frame) is not None - ): - local_vars = { - var_name: var_value - for var_name, var_value in frame.locals.items() - if "@" not in var_name - } - global_vars = { - var_name: var_value - for var_name, var_value in frame.globals.items() - if "@" not in var_name - } - else: - local_vars = {} - global_vars = {} - result = v while not isinstance(result, Custom): @@ -52,8 +40,8 @@ def _get_handler(self, v) -> Custom: r = state().pm.hook.customize( value=result, builder=self, - local_vars=local_vars, - global_vars=global_vars, + local_vars=self._get_local_vars, + global_vars=self._get_global_vars, ) if r is None: result = CustomCode(result) @@ -61,6 +49,13 @@ def _get_handler(self, v) -> Custom: result = r result.__dict__["original_value"] = v + + if not isinstance(v, Custom) and self._build_new_value: + if result._eval() != v: + raise UsageError( + f"Customized value does not match original value: {result._eval()!r} != {v!r}" + ) + return result def create_external( @@ -137,13 +132,95 @@ def create_dict(self, value: dict) -> Custom: custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} return CustomDict(value=custom) - def create_code(self, value: Any, repr: str | None = None) -> CustomCode: + @cached_property + def _get_local_vars(self): + """Get local vars from snapshot context.""" + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + return { + var_name: var_value + for var_name, var_value in frame.locals.items() + if "@" not in var_name + } + return {} + + @cached_property + def _get_global_vars(self): + """Get global vars from snapshot context.""" + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + return { + var_name: var_value + for var_name, var_value in frame.globals.items() + if "@" not in var_name + } + return {} + + def _build_import_vars(self, imports): + """Build import vars from imports parameter.""" + import_vars = {} + if imports: + import importlib + + for imp in imports: + if isinstance(imp, Import): + # import module - makes top-level package available + importlib.import_module(imp.module) + top_level = imp.module.split(".")[0] + import_vars[top_level] = importlib.import_module(top_level) + elif isinstance(imp, ImportFrom): + # from module import name + module = importlib.import_module(imp.module) + import_vars[imp.name] = getattr(module, imp.name) + return import_vars + + def create_code( + self, code: str, *, imports: list[Import | ImportFrom] = [] + ) -> Custom: """ Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. - `create_code(value, '{value-1!r}+1')` becomes `4+1` in the code for a given `value=5`. + `create_code('{value-1!r}+1')` becomes `4+1` in the code. Use this when you need to control the exact string representation of a value. - You can use [`.with_import_from(module,name)`][inline_snapshot.plugin.CustomCode.with_import_from] to create an import in the code. + Arguments: + code: Custom string representation to evaluate. This is required and will be evaluated using the snapshot context. + imports: Optional list of Import and ImportFrom objects to add required imports to the generated code. + Example: `imports=[Import("os"), ImportFrom("pathlib", "Path")]` """ - return CustomCode(value, repr) + import_vars = None + + # Try direct variable lookup for simple identifiers (fastest) + if code.isidentifier(): + # Direct lookup with proper precedence: local > import > global + if code in self._get_local_vars: + return CustomCode(self._get_local_vars[code], code, imports) + + # Build import vars only if needed + import_vars = self._build_import_vars(imports) + if code in import_vars: + return CustomCode(import_vars[code], code, imports) + + if code in self._get_global_vars: + return CustomCode(self._get_global_vars[code], code, imports) + + # Try ast.literal_eval for simple literals (fast and safe) + try: + import ast + + return CustomCode(ast.literal_eval(code), code, imports) + except (ValueError, SyntaxError): + # Fall back to eval with context for complex expressions + # Build evaluation context with proper precedence: global < import < local + if import_vars is None: + import_vars = self._build_import_vars(imports) + eval_context = { + **self._get_global_vars, + **import_vars, + **self._get_local_vars, + } + return CustomCode(eval(code, eval_context), code, imports) diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index dda78fc8..c7238051 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -2,10 +2,9 @@ import ast import importlib +from dataclasses import dataclass from typing import Generator -from typing_extensions import Self - from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import ChangeBase from inline_snapshot._change import RequiredImport @@ -16,6 +15,21 @@ from ._custom import Custom +@dataclass(frozen=True) +class Import: + """Represents an import statement: `import module`""" + + module: str + + +@dataclass(frozen=True) +class ImportFrom: + """Represents a from-import statement: `from module import name`""" + + module: str + name: str + + def _simplify_module_path(module: str, name: str) -> str: """Simplify module path by finding the shortest import path for a given name.""" value = getattr(importlib.import_module(module), name) @@ -29,14 +43,12 @@ def _simplify_module_path(module: str, name: str) -> str: class CustomCode(Custom): - _imports: list[tuple[str, str]] - _module_imports: list[str] + _imports: list[Import | ImportFrom] - def __init__(self, value, repr_str=None): + def __init__(self, value, repr_str=None, imports: list[Import | ImportFrom] = []): assert not isinstance(value, Custom) value = clone(value) - self._imports = [] - self._module_imports = [] + self._imports = list(imports) if repr_str is None: self.repr_str = value_code_repr(value) @@ -45,7 +57,7 @@ def __init__(self, value, repr_str=None): ast.parse(self.repr_str) except SyntaxError: self.repr_str = HasRepr(type(value), self.repr_str).__repr__() - self.with_import_from("inline_snapshot", "HasRepr") + self._imports.append(ImportFrom("inline_snapshot", "HasRepr")) else: self.repr_str = repr_str @@ -57,61 +69,18 @@ def _map(self, f): return f(self.value) def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: - for module in self._module_imports: - yield RequiredImport(flag="fix", file=context.file, module=module) - for module, name in self._imports: - yield RequiredImport( - flag="fix", file=context.file, module=module, name=name - ) + for imp in self._imports: + if isinstance(imp, Import): + yield RequiredImport(flag="fix", file=context.file, module=imp.module) + elif isinstance(imp, ImportFrom): + yield RequiredImport( + flag="fix", + file=context.file, + module=_simplify_module_path(imp.module, imp.name), + name=imp.name, + ) return self.repr_str def __repr__(self): return f"CustomValue({self.repr_str})" - - def with_import_from(self, module: str, name: str, simplify: bool = True) -> Self: - """ - Adds a `from module import name` statement to the generated code. - - Arguments: - module: The module path to import from (e.g., "my_module" or "package.submodule"). - name: The name to import from the module (e.g., "MyClass" or "my_function"). - simplify: If True (default), attempts to find the shortest valid import path - by checking parent modules. For example, if "package.submodule.MyClass" - is accessible from "package", it will use the shorter path. - - Returns: - The CustomCode instance itself, allowing for method chaining. - - Example: - ``` python - builder.create_code(my_obj, "secrets[0]").with_import_from( - "my_secrets", "secrets" - ) - ``` - """ - name = name.split("[")[0] - if simplify: - module = _simplify_module_path(module, name) - self._imports.append((module, name)) - - return self - - def with_import(self, module: str) -> Self: - """ - Adds an `import module` statement to the generated code. - - Arguments: - module: The module path to import (e.g., "os.path" or "collections.abc"). - - Returns: - The CustomCode instance itself, allowing for method chaining. - - Example: - ``` python - builder.create_code(my_obj, "os.path.join('a', 'b')").with_import("os.path") - ``` - """ - self._module_imports.append(module) - - return self diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py index 091e33a5..65b8fc68 100644 --- a/src/inline_snapshot/plugin/__init__.py +++ b/src/inline_snapshot/plugin/__init__.py @@ -1,4 +1,5 @@ -from inline_snapshot._customize._custom_code import CustomCode +from inline_snapshot._customize._custom_code import Import +from inline_snapshot._customize._custom_code import ImportFrom from .._customize._builder import Builder from .._customize._custom import Custom @@ -12,5 +13,6 @@ "hookimpl", "Builder", "Custom", - "CustomCode", + "Import", + "ImportFrom", ) diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 960db91d..e5db9996 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -14,6 +14,7 @@ from typing import Dict from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom_code import ImportFrom from inline_snapshot._customize._custom_undefined import CustomUndefined from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged from inline_snapshot._external._outsource import Outsourced @@ -47,7 +48,7 @@ def string_handler(self, value, builder: Builder): assert ast.literal_eval(triple_quoted_string) == value - return builder.create_code(value, triple_quoted_string) + return builder.create_code(triple_quoted_string) @customize(tryfirst=True) def counter_handler(self, value, builder: Builder): @@ -61,18 +62,18 @@ def function_and_type_handler( if isinstance(value, (FunctionType, type)): for name, local_value in local_vars.items(): if local_value is value: - return builder.create_code(value, name) + return builder.create_code(name) qualname = value.__qualname__.split("[")[0] name = qualname.split(".")[0] - return builder.create_code(value, qualname).with_import_from( - value.__module__, name + return builder.create_code( + qualname, imports=[ImportFrom(value.__module__, name)] ) @customize def builtin_function_handler(self, value, builder: Builder): if isinstance(value, BuiltinFunctionType): - return builder.create_code(value, value.__name__) + return builder.create_code(value.__name__) @customize def path_handler(self, value, builder: Builder): @@ -100,17 +101,17 @@ def sort_set_values(self, set_values): def set_handler(self, value, builder: Builder): if isinstance(value, set): if len(value) == 0: - return builder.create_code(value, "set()") + return builder.create_code("set()") else: return builder.create_code( - value, "{" + ", ".join(self.sort_set_values(value)) + "}" + "{" + ", ".join(self.sort_set_values(value)) + "}" ) @customize def frozenset_handler(self, value, builder: Builder): if isinstance(value, frozenset): if len(value) == 0: - return builder.create_code(value, "frozenset()") + return builder.create_code("frozenset()") else: return builder.create_call(frozenset, [set(value)]) @@ -122,8 +123,9 @@ def enum_handler(self, value, builder: Builder): name = qualname.split(".")[0] return builder.create_code( - value, f"{type(value).__qualname__}.{value.name}" - ).with_import_from(type(value).__module__, name) + f"{type(value).__qualname__}.{value.name}", + imports=[ImportFrom(type(value).__module__, name)], + ) # -8<- [end:Enum] @@ -134,16 +136,16 @@ def flag_handler(self, value, builder: Builder): name = qualname.split(".")[0] return builder.create_code( - value, " | ".join( f"{qualname}.{flag.name}" for flag in type(value) if flag in value ), - ).with_import_from(type(value).__module__, name) + imports=[ImportFrom(type(value).__module__, name)], + ) @customize def source_file_name_handler(self, value, builder: Builder, global_vars): if "__file__" in global_vars and value == global_vars["__file__"]: - return builder.create_code(value, "__file__") + return builder.create_code("__file__") @customize def dataclass_handler(self, value, builder: Builder): @@ -219,8 +221,9 @@ def dirty_equals_handler(self, value, builder: Builder): if is_dirty_equal(value) and builder._build_new_value: if isinstance(value, type): - return builder.create_code(value, value.__name__).with_import_from( - "dirty_equals", value.__name__ + return builder.create_code( + value.__name__, + imports=[ImportFrom("dirty_equals", value.__name__)], ) else: from dirty_equals import IsNow diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index dad9ad89..c0b91b5a 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -68,12 +68,12 @@ class InlineSnapshotPlugin: @customize def binary_numbers(self, value, builder, local_vars, global_vars): if isinstance(value, int): - return builder.create_code(value, bin(value)) + return builder.create_code(bin(value)) @customize def repeated_strings(self, value, builder): if isinstance(value, str) and value == value[0] * len(value): - return builder.create_code(value, f"'{value[0]}'*{len(value)}") + return builder.create_code(f"'{value[0]}'*{len(value)}") ``` === "by method name" @@ -83,10 +83,10 @@ def repeated_strings(self, value, builder): class InlineSnapshotPlugin: def customize(self, value, builder, local_vars, global_vars): if isinstance(value, int): - return builder.create_code(value, bin(value)) + return builder.create_code(bin(value)) if isinstance(value, str) and value == value[0] * len(value): - return builder.create_code(value, f"'{value[0]}'*{len(value)}") + return builder.create_code(f"'{value[0]}'*{len(value)}") ``` diff --git a/tests/conftest.py b/tests/conftest.py index c3a70545..dd85943a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -317,12 +317,12 @@ class InlineSnapshotPlugin: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): - return builder.create_code(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) + return builder.create_code(value.__repr__().replace("FakeDatetime","datetime.datetime")) @customize def fakedate_handler(self,value,builder): if isinstance(value,FakeDate): - return builder.create_code(value,value.__repr__().replace("FakeDate","datetime.date")) + return builder.create_code(value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) diff --git a/tests/test_builder.py b/tests/test_builder.py index b2a5f7ba..3ef6505d 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -13,7 +13,7 @@ class InlineSnapshotPlugin: @customize def handler(self,value,builder): if value==5: - return builder.with_default(5,builder.create_code(8)) + return builder.with_default(5,builder.create_code("8")) """, "test_a.py": """\ from inline_snapshot import snapshot diff --git a/tests/test_customize.py b/tests/test_customize.py index 551142d7..227c995e 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -97,7 +97,7 @@ def test_with_import(original, flag): { "conftest.py": """\ from inline_snapshot.plugin import customize -from inline_snapshot.plugin import Builder +from inline_snapshot.plugin import Builder, Import from pkg.subpkg import ComplexObj class InlineSnapshotPlugin: @@ -105,9 +105,9 @@ class InlineSnapshotPlugin: def complex_handler(self, value, builder: Builder): if isinstance(value, ComplexObj): return builder.create_code( - value, - f"mod1.helper(pkg.subpkg.create({value.a!r}, {value.b!r}))" - ).with_import("mod1").with_import("pkg.subpkg") + f"mod1.helper(pkg.subpkg.create({value.a!r}, {value.b!r}))", + imports=[Import("mod1"), Import("pkg.subpkg")] + ) """, "mod1.py": """\ def helper(obj): @@ -164,7 +164,7 @@ def test_with_import_preserves_existing(original, flag, existing_import): { "conftest.py": """\ from inline_snapshot.plugin import customize -from inline_snapshot.plugin import Builder +from inline_snapshot.plugin import Builder, Import from mymodule import MyClass class InlineSnapshotPlugin: @@ -172,9 +172,9 @@ class InlineSnapshotPlugin: def myclass_handler(self, value, builder: Builder): if isinstance(value, MyClass): return builder.create_code( - value, - f"mymodule.MyClass({value.value!r})" - ).with_import("mymodule") + f"mymodule.MyClass({value.value!r})", + imports=[Import("mymodule")] + ) """, "mymodule.py": """\ class MyClass: @@ -213,3 +213,37 @@ def test_a(): } ), ).run_inline() + + +def test_customized_value_mismatch_error(): + """Test that UsageError is raised when customized value doesn't match original.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + +class InlineSnapshotPlugin: + @customize + def bad_handler(self, value, builder: Builder): + if value == 42: + # Return a CustomCode with wrong value - repr evaluates to 100 but original is 42 + return builder.create_code("100") +""", + "test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot() == 42 +""", + } + ).run_inline( + ["--inline-snapshot=create"], + raises=snapshot( + """\ +UsageError: +Customized value does not match original value: 100 != 42\ +""" + ), + ) diff --git a/tests/test_docs.py b/tests/test_docs.py index 503d5305..1a4a9703 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -319,12 +319,12 @@ class InlineSnapshotPlugin: @customize def fakedatetime_handler(self,value,builder): if isinstance(value,FakeDatetime): - return builder.create_code(value,value.__repr__().replace("FakeDatetime","datetime.datetime")) + return builder.create_code(value.__repr__().replace("FakeDatetime","datetime.datetime")) @customize def fakedate_handler(self,value,builder): if isinstance(value,FakeDate): - return builder.create_code(value,value.__repr__().replace("FakeDate","datetime.date")) + return builder.create_code(value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) From d6eac4241fbc3d24ab42cbb7862cdd7c37b8f3e7 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 21 Jan 2026 09:07:37 +0100 Subject: [PATCH 61/72] test: coverage --- .github/workflows/ci.yml | 3 +- src/inline_snapshot/_customize/_builder.py | 2 + .../_customize/_custom_code.py | 2 + tests/test_customize.py | 42 +++++++++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c794ad2..f37e6616 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,7 +122,8 @@ jobs: uvx coverage html --skip-covered --skip-empty # Report and write to summary. - uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY + uvx coverage report --format=markdown --skip-covered --skip-empty -m >> $GITHUB_STEP_SUMMARY + uvx coverage report --skip-covered --skip-empty > htmlcov/report.txt # Report again and fail if under 100%. uvx coverage report --fail-under=100 diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 21a49cab..d29999a0 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -176,6 +176,8 @@ def _build_import_vars(self, imports): # from module import name module = importlib.import_module(imp.module) import_vars[imp.name] = getattr(module, imp.name) + else: + assert False return import_vars def create_code( diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index c7238051..e9f14925 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -79,6 +79,8 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str module=_simplify_module_path(imp.module, imp.name), name=imp.name, ) + else: + assert False return self.repr_str diff --git a/tests/test_customize.py b/tests/test_customize.py index 227c995e..654568ce 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -247,3 +247,45 @@ def test_a(): """ ), ) + + +@pytest.mark.parametrize("original,flag", [("'wrong'", "fix"), ("", "create")]) +def test_global_var_lookup(original, flag): + """Test that create_code can look up global variables.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + +class InlineSnapshotPlugin: + @customize + def use_global(self, value, builder: Builder): + if value == "test_value": + return builder.create_code("GLOBAL_VAR") +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot + +GLOBAL_VAR = "test_value" + +def test_a(): + assert snapshot({original}) == "test_value" +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +GLOBAL_VAR = "test_value" + +def test_a(): + assert snapshot(GLOBAL_VAR) == "test_value" +""" + } + ), + ) From 4215c8b79532dd808df8ec39ba803f5a73f7a936 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 22 Jan 2026 08:44:33 +0100 Subject: [PATCH 62/72] test: coverage --- src/inline_snapshot/_global_state.py | 20 ++---- src/inline_snapshot/plugin/_default_plugin.py | 72 +++++++++++++------ src/inline_snapshot/testing/_example.py | 8 +-- tests/conftest.py | 2 +- tests/test_customize.py | 28 ++++++++ tests/test_dirty_equals.py | 29 ++++++++ 6 files changed, 118 insertions(+), 41 deletions(-) diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 89053f3d..4db55a15 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -86,23 +86,15 @@ def enter_snapshot_context(): _current.pm.add_hookspecs(InlineSnapshotPluginSpec) + from .plugin._default_plugin import InlineSnapshotAttrsPlugin + from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin from .plugin._default_plugin import InlineSnapshotPlugin + from .plugin._default_plugin import InlineSnapshotPydanticPlugin _current.pm.register(InlineSnapshotPlugin()) - - try: - from .plugin._default_plugin import InlineSnapshotAttrsPlugin - except ImportError: # pragma: no cover - pass - else: - _current.pm.register(InlineSnapshotAttrsPlugin()) - - try: - from .plugin._default_plugin import InlineSnapshotPydanticPlugin - except ImportError: # pragma: no cover - pass - else: - _current.pm.register(InlineSnapshotPydanticPlugin()) + _current.pm.register(InlineSnapshotAttrsPlugin()) + _current.pm.register(InlineSnapshotPydanticPlugin()) + _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) _current.pm.load_setuptools_entrypoints(inline_snapshot_plugin_name) diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index e5db9996..cccef739 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -215,26 +215,6 @@ def undefined_handler(self, value, builder: Builder): if value is undefined: return CustomUndefined() - @customize(tryfirst=True) - def dirty_equals_handler(self, value, builder: Builder): - - if is_dirty_equal(value) and builder._build_new_value: - - if isinstance(value, type): - return builder.create_code( - value.__name__, - imports=[ImportFrom("dirty_equals", value.__name__)], - ) - else: - from dirty_equals import IsNow - from dirty_equals._utils import Omit - - args = [a for a in value._repr_args if a is not Omit] - kwargs = {k: a for k, a in value._repr_kwargs.items() if a is not Omit} - if type(value) == IsNow: - kwargs.pop("approx") - return builder.create_call(type(value), args, kwargs) - @customize def outsource_handler(self, value, builder: Builder): if isinstance(value, Outsourced): @@ -243,11 +223,57 @@ def outsource_handler(self, value, builder: Builder): ) +try: + pass +except ImportError: # pragma: no cover + + class InlineSnapshotDirtyEqualsPlugin: + pass + +else: + import datetime + + from dirty_equals import IsNow + from dirty_equals._utils import Omit + + class InlineSnapshotDirtyEqualsPlugin: + @customize(tryfirst=True) + def dirty_equals_handler(self, value, builder: Builder): + + if is_dirty_equal(value) and builder._build_new_value: + + if isinstance(value, type): + return builder.create_code( + value.__name__, + imports=[ImportFrom("dirty_equals", value.__name__)], + ) + else: + + args = [a for a in value._repr_args if a is not Omit] + kwargs = { + k: a for k, a in value._repr_kwargs.items() if a is not Omit + } + if type(value) == IsNow: + kwargs.pop("approx") + if ( + isinstance(delta := kwargs["delta"], datetime.timedelta) + and delta.total_seconds() == 2 + ): + kwargs.pop("delta") + return builder.create_call(type(value), args, kwargs) + + @customize(tryfirst=True) + def is_now_handler(self, value, builder: Builder): + if value == IsNow(): + return IsNow() + + try: import attrs except ImportError: # pragma: no cover - pass + class InlineSnapshotAttrsPlugin: + pass else: @@ -286,7 +312,9 @@ def attrs_handler(self, value, builder: Builder): try: import pydantic except ImportError: # pragma: no cover - pass + + class InlineSnapshotPydanticPlugin: + pass else: # import pydantic diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 7121e360..dc519d37 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -355,9 +355,9 @@ def report_error(message): conftest_module ) else: - raise UsageError( - f"Could not load conftest from {conftest_path}" - ) + assert ( + False + ), f"Could not load conftest from {conftest_path}" tests_found = False for filename in tmp_path.rglob("test_*.py"): @@ -372,7 +372,7 @@ def report_error(message): sys.modules[filename.stem] = module spec.loader.exec_module(module) else: - raise UsageError(f"Could not load module from {filename}") + assert False, f"Could not load module from {filename}" # run all test_* functions tests = [ diff --git a/tests/conftest.py b/tests/conftest.py index dd85943a..48ad4570 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,7 +124,7 @@ def run(self, *flags_arg: Category): sys.modules[filename.stem] = module spec.loader.exec_module(module) else: - raise RuntimeError(f"Could not load module from {filename}") + assert False, f"Could not load module from {filename}" except AssertionError: traceback.print_exc() error = True diff --git a/tests/test_customize.py b/tests/test_customize.py index 654568ce..15f5302c 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -289,3 +289,31 @@ def test_a(): } ), ) + + +@pytest.mark.parametrize("original,flag", [("'wrong'", "fix"), ("", "create")]) +def test_file_handler(original, flag): + """Test that __file__ handler creates correct code.""" + + Example( + { + "test_something.py": f"""\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot({original}) == __file__ +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot(__file__) == __file__ +""" + } + ), + ) diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py index 2d51188f..1098ace5 100644 --- a/tests/test_dirty_equals.py +++ b/tests/test_dirty_equals.py @@ -159,3 +159,32 @@ def test_number(): } ), ) + + +def test_is_now_without_approx() -> None: + """Test that IsNow handler correctly removes 'approx' kwarg.""" + + Example( + """\ +from datetime import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot + +def test_time(): + assert datetime.now() == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from datetime import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot + +def test_time(): + assert datetime.now() == snapshot(IsNow()) +""" + } + ), + ) From 11b26e5392f42adf8f1a28243d12e624d22b7fcc Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Thu, 22 Jan 2026 09:15:20 +0100 Subject: [PATCH 63/72] feat: removed kw_only_args from CustomCall --- .github/workflows/ci.yml | 2 +- docs/eq_snapshot.md | 42 ++++--------------- src/inline_snapshot/_customize/_builder.py | 10 ++--- .../_customize/_custom_call.py | 27 ++++-------- src/inline_snapshot/_new_adapter.py | 13 +++--- src/inline_snapshot/plugin/_default_plugin.py | 9 ++-- src/inline_snapshot/testing/_example.py | 34 +++++++-------- 7 files changed, 45 insertions(+), 92 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f37e6616..caa69a99 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -123,7 +123,7 @@ jobs: # Report and write to summary. uvx coverage report --format=markdown --skip-covered --skip-empty -m >> $GITHUB_STEP_SUMMARY - uvx coverage report --skip-covered --skip-empty > htmlcov/report.txt + uvx coverage report --skip-covered --skip-empty -m > htmlcov/report.txt # Report again and fail if under 100%. uvx coverage report --fail-under=100 diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index 41340811..fd3e1df7 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -135,31 +135,11 @@ def test_function(): If you use `--inline-snapshot=create`, inline-snapshot will record the current `datetime` in the snapshot: -``` python hl_lines="13 14 15" +``` python hl_lines="4 5 15" import datetime from inline_snapshot import snapshot - -def get_data(): - return { - "date": datetime.datetime.utcnow(), - "payload": "some data", - } - - -def test_function(): - assert get_data() == snapshot( - {"date": datetime.datetime(2024, 3, 14, 0, 0), "payload": "some data"} - ) -``` - -To avoid the test failing in future runs, replace the `datetime` with [dirty-equals' `IsDatetime()`](https://dirty-equals.helpmanual.io/latest/types/datetime/#dirty_equals.IsDatetime): - - -``` python -import datetime -from dirty_equals import IsDatetime -from inline_snapshot import snapshot +from dirty_equals import IsNow def get_data(): @@ -170,20 +150,16 @@ def get_data(): def test_function(): - assert get_data() == snapshot( - { - "date": IsDatetime(), - "payload": "some data", - } - ) + assert get_data() == snapshot({"date": IsNow(), "payload": "some data"}) ``` +Inline-snapshot uses [dirty-equals `IsNow()`](https://dirty-equals.helpmanual.io/latest/types/datetime/#dirty_equals.IsDatetime) by default when the value is equal to the current time to avoid the test failing in future runs. Say a different part of the return data changes, such as the `payload` value: -``` python hl_lines="9" +``` python hl_lines="3 9 14 15 16 17 18 19" import datetime -from dirty_equals import IsDatetime +from dirty_equals import IsNow from inline_snapshot import snapshot @@ -197,7 +173,7 @@ def get_data(): def test_function(): assert get_data() == snapshot( { - "date": IsDatetime(), + "date": IsNow(), "payload": "some data", } ) @@ -208,7 +184,7 @@ Re-running the test with `--inline-snapshot=fix` will update the snapshot to mat ``` python hl_lines="17" import datetime -from dirty_equals import IsDatetime +from dirty_equals import IsNow from inline_snapshot import snapshot @@ -222,7 +198,7 @@ def get_data(): def test_function(): assert get_data() == snapshot( { - "date": IsDatetime(), + "date": IsNow(), "payload": "data changed for some good reason", } ) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index d29999a0..eb79e5e5 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -102,7 +102,7 @@ def with_default(self, value: Any, default: Any): return value def create_call( - self, function: Custom | Callable, posonly_args=[], kwargs={}, kwonly_args={} + self, function: Custom | Callable, posonly_args=[], kwargs={} ) -> Custom: """ Creates an intermediate node for a function call expression which can be used as a result for your customization function. @@ -113,13 +113,11 @@ def create_call( function = self._get_handler(function) posonly_args = [self._get_handler(arg) for arg in posonly_args] kwargs = {k: self._get_handler(arg) for k, arg in kwargs.items()} - kwonly_args = {k: self._get_handler(arg) for k, arg in kwonly_args.items()} return CustomCall( - _function=function, - _args=posonly_args, - _kwargs=kwargs, - _kwonly=kwonly_args, + function=function, + args=posonly_args, + kwargs=kwargs, ) def create_dict(self, value: dict) -> Custom: diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py index 528b43d6..f6ea8eb9 100644 --- a/src/inline_snapshot/_customize/_custom_call.py +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -33,10 +33,9 @@ def unwrap_default(value): @dataclass(frozen=True) class CustomCall(Custom): node_type = ast.Call - _function: Custom = field(compare=False) - _args: list[Custom] = field(compare=False) - _kwargs: dict[str, Custom] = field(compare=False) - _kwonly: dict[str, Custom] = field(default_factory=dict, compare=False) + function: Custom = field(compare=False) + args: list[Custom] = field(compare=False) + kwargs: dict[str, Custom] = field(compare=False) def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: args = [] @@ -49,28 +48,16 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str value = yield from v._code_repr(context) args.append(f"{k}={value}") - return f"{yield from self._function._code_repr(context)}({', '.join(args)})" - - @property - def args(self): - return self._args - - @property - def all_pos_args(self): - return [*self._args, *self._kwargs.values()] - - @property - def kwargs(self): - return {**self._kwargs, **self._kwonly} + return f"{yield from self.function._code_repr(context)}({', '.join(args)})" def argument(self, pos_or_str): if isinstance(pos_or_str, int): - return unwrap_default(self.all_pos_args[pos_or_str]) + return unwrap_default(self.args[pos_or_str]) else: return unwrap_default(self.kwargs[pos_or_str]) def _map(self, f): - return self._function._map(f)( - *[f(x._map(f)) for x in self._args], + return self.function._map(f)( + *[f(x._map(f)) for x in self.args], **{k: f(v._map(f)) for k, v in self.kwargs.items()}, ) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index db5a2c6d..a76286c5 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -119,10 +119,9 @@ def reeval_CustomCode(old_value: CustomCode, value: CustomCode): def reeval_CustomCall(old_value: CustomCall, value: CustomCall): return CustomCall( - reeval(old_value._function, value._function), - [reeval(a, b) for a, b in zip(old_value._args, value._args)], - {k: reeval(old_value._kwargs[k], value._kwargs[k]) for k in old_value._kwargs}, - {k: reeval(old_value._kwonly[k], value._kwonly[k]) for k in old_value._kwonly}, + reeval(old_value.function, value.function), + [reeval(a, b) for a, b in zip(old_value.args, value.args)], + {k: reeval(old_value.kwargs[k], value.kwargs[k]) for k in old_value.kwargs}, ) @@ -443,7 +442,7 @@ def intercept(change): # keyword arguments result_kwargs = {} if old_node is None: - old_keywords = {key: None for key in old_value._kwargs.keys()} + old_keywords = {key: None for key in old_value.kwargs.keys()} else: old_keywords = {kw.arg: kw.value for kw in old_node.keywords} @@ -516,9 +515,9 @@ def intercept(change): ( yield from intercept( self.compare( - old_value._function, + old_value.function, old_node.func if old_node else None, - new_value._function, + new_value.function, ) ) ), diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index cccef739..7f3a0806 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -168,7 +168,7 @@ def dataclass_handler(self, value, builder: Builder): kwargs[field.name] = field_value - return builder.create_call(type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs) @customize def namedtuple_handler(self, value, builder: Builder): @@ -195,14 +195,13 @@ def namedtuple_handler(self, value, builder: Builder): ) for field in value._fields }, - {}, ) @customize(tryfirst=True) def defaultdict_handler(self, value, builder: Builder): if isinstance(value, defaultdict): return builder.create_call( - type(value), [value.default_factory, dict(value)], {}, {} + type(value), [value.default_factory, dict(value)], {} ) @customize @@ -306,7 +305,7 @@ def attrs_handler(self, value, builder: Builder): kwargs[field.name] = field_value - return builder.create_call(type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs) try: @@ -361,4 +360,4 @@ def pydantic_model_handler(self, value, builder: Builder): kwargs[name] = field_value - return builder.create_call(type(value), [], kwargs, {}) + return builder.create_call(type(value), [], kwargs) diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index dc519d37..98211fa8 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -344,20 +344,15 @@ def report_error(message): spec = importlib.util.spec_from_file_location( f"conftest_{conftest_path.parent.name}", conftest_path ) - if spec and spec.loader: - conftest_module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = conftest_module - conftest_module.__file__ = str(conftest_path) - spec.loader.exec_module(conftest_module) - - # Register customize hooks from this conftest - session.register_customize_hooks_from_module( - conftest_module - ) - else: - assert ( - False - ), f"Could not load conftest from {conftest_path}" + assert spec and spec.loader + + conftest_module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = conftest_module + conftest_module.__file__ = str(conftest_path) + spec.loader.exec_module(conftest_module) + + # Register customize hooks from this conftest + session.register_customize_hooks_from_module(conftest_module) tests_found = False for filename in tmp_path.rglob("test_*.py"): @@ -367,12 +362,11 @@ def report_error(message): spec = importlib.util.spec_from_file_location( filename.stem, filename ) - if spec and spec.loader: - module = importlib.util.module_from_spec(spec) - sys.modules[filename.stem] = module - spec.loader.exec_module(module) - else: - assert False, f"Could not load module from {filename}" + assert spec and spec.loader + + module = importlib.util.module_from_spec(spec) + sys.modules[filename.stem] = module + spec.loader.exec_module(module) # run all test_* functions tests = [ From a4daaf564b437a9a345fcbd368db4b0a970870a6 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 24 Jan 2026 20:25:46 +0100 Subject: [PATCH 64/72] feat: changed plugin implementation --- docs/plugin.md | 91 ++++++++++--------- pyproject.toml | 1 + src/inline_snapshot/_customize/_builder.py | 9 +- .../_customize/_custom_call.py | 8 +- .../_customize/_custom_code.py | 2 +- src/inline_snapshot/_global_state.py | 27 ++++-- src/inline_snapshot/_new_adapter.py | 4 +- src/inline_snapshot/_snapshot_session.py | 11 +++ src/inline_snapshot/plugin/_default_plugin.py | 9 +- src/inline_snapshot/plugin/_spec.py | 2 +- tests/test_customize.py | 7 +- 11 files changed, 105 insertions(+), 66 deletions(-) diff --git a/docs/plugin.md b/docs/plugin.md index e5f142e1..e9de7b67 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -36,7 +36,19 @@ It searches for plugins in: Loading plugins from the `conftest.py` files is the recommended way when you want to change the behavior of inline-snapshot in your own project. -The plugins are searched in your `conftest.py` and the name has to start with `InlineSnapshot*`. Each plugin which is loaded from your `conftest.py` is active globally for all your tests. +Simply use `@customize` on functions directly in your `conftest.py`: + +``` python +from inline_snapshot.plugin import customize + + +@customize +def my_handler(value, builder): + # your logic + pass +``` + +All customizations defined in your `conftest.py` are active globally for all your tests. ### Creating a Plugin Package @@ -44,7 +56,7 @@ To distribute inline-snapshot plugins as a package, register your plugin class u === "pyproject.toml (recommended)" ``` toml - [project.entry-points.inline-snapshot] + [project.entry-points.inline_snapshot] my_plugin = "my_package.plugin:MyInlineSnapshotPlugin" ``` @@ -53,8 +65,8 @@ To distribute inline-snapshot plugins as a package, register your plugin class u setup( name="my-inline-snapshot-plugin", entry_points={ - "inline-snapshot": [ - "my_plugin = my_package.plugin:MyInlineSnapshotPlugin", + "inline_snapshot": [ + "my_plugin = my_package.plugin", ], }, ) @@ -66,17 +78,11 @@ Your plugin class should contain methods decorated with `@customize`, just like from inline_snapshot.plugin import customize, Builder -class MyInlineSnapshotPlugin: - """ - This class will be instantiated by inline-snapshot when the package is installed. - Typically used by library authors who want to provide inline-snapshot integration. - """ - - @customize - def my_custom_handler(self, value, builder: Builder): - # Your customization logic here - if isinstance(value, YourCustomType): - return builder.create_call(YourCustomType, [value.arg]) +@customize +def my_custom_handler(value, builder: Builder): + # Your customization logic here + if isinstance(value, YourCustomType): + return builder.create_call(YourCustomType, [value.arg]) ``` Once installed, the plugin will be automatically loaded by inline-snapshot. @@ -128,11 +134,10 @@ from inline_snapshot.plugin import customize from inline_snapshot.plugin import Builder -class InlineSnapshotPlugin: - @customize - def square_handler(self, value, builder: Builder): - if isinstance(value, Rect) and value.width == value.height: - return builder.create_call(Rect.make_square, [value.width]) +@customize +def square_handler(value, builder: Builder): + if isinstance(value, Rect) and value.width == value.height: + return builder.create_call(Rect.make_square, [value.width]) ``` This allows you to influence the code that is created by inline-snapshot. @@ -167,11 +172,10 @@ from inline_snapshot.plugin import Builder from dirty_equals import IsNow -class InlineSnapshotPlugin: - @customize - def is_now_handler(self, value): - if value == IsNow(): - return IsNow +@customize +def is_now_handler(value): + if value == IsNow(): + return IsNow ``` As explained in the [customize hook specification][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize], you can return types other than Custom objects. Inline-snapshot includes a built-in handler in its default plugin that converts dirty-equals expressions back into source code, which is why you can return `IsNow` directly without using the builder. This approach is much simpler than using `builder.create_call()` for complex dirty-equals expressions. @@ -204,11 +208,10 @@ from inline_snapshot.plugin import customize from inline_snapshot.plugin import Builder -class InlineSnapshotPlugin: - @customize - def long_string_handler(self, value, builder: Builder): - if isinstance(value, str) and value.count("\n") > 5: - return builder.create_external(value) +@customize +def long_string_handler(value, builder: Builder): + if isinstance(value, str) and value.count("\n") > 5: + return builder.create_external(value) ``` @@ -240,12 +243,11 @@ from inline_snapshot.plugin import customize from inline_snapshot.plugin import Builder -class InlineSnapshotPlugin: - @customize - def local_var_handler(self, value, builder, local_vars): - for var_name, var_value in local_vars.items(): - if var_name.startswith("v_") and var_value == value: - return builder.create_code(var_name) +@customize +def local_var_handler(value, builder, local_vars): + for var_name, var_value in local_vars.items(): + if var_name.startswith("v_") and var_value == value: + return builder.create_code(var_name) ``` We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local variable if we find one that fits the criteria. @@ -317,15 +319,14 @@ from my_secrets import secrets from inline_snapshot.plugin import customize, Builder, ImportFrom -class InlineSnapshotPlugin: - @customize - def secret_handler(self, value, builder: Builder): - for i, secret in enumerate(secrets): - if value == secret: - return builder.create_code( - f"secrets[{i}]", - imports=[ImportFrom("my_secrets", "secrets")], - ) +@customize +def secret_handler(value, builder: Builder): + for i, secret in enumerate(secrets): + if value == secret: + return builder.create_code( + f"secrets[{i}]", + imports=[ImportFrom("my_secrets", "secrets")], + ) ``` The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the desired code representation. The `imports` parameter adds the necessary import statements. diff --git a/pyproject.toml b/pyproject.toml index 6cbb6392..de02550c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "rich>=13.7.1", "tomli>=2.0.0; python_version < '3.11'", "pytest>=8.3.4", + "typing-extensions" ] description = "golden master/snapshot/approval testing library which puts the values right into your source code" keywords = [] diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index eb79e5e5..d58c07fc 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -53,7 +53,14 @@ def _get_handler(self, v) -> Custom: if not isinstance(v, Custom) and self._build_new_value: if result._eval() != v: raise UsageError( - f"Customized value does not match original value: {result._eval()!r} != {v!r}" + f"""\ +Customized value does not match original value: + +original_value={v!r} + +customized_value={result._eval()!r} +customized_representation={result!r} +""" ) return result diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py index f6ea8eb9..e2c884f9 100644 --- a/src/inline_snapshot/_customize/_custom_call.py +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -40,13 +40,13 @@ class CustomCall(Custom): def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: args = [] for a in self.args: - v = yield from a._code_repr(context) - args.append(v) + code = yield from a._code_repr(context) + args.append(code) for k, v in self.kwargs.items(): if not isinstance(v, CustomDefault): - value = yield from v._code_repr(context) - args.append(f"{k}={value}") + code = yield from v._code_repr(context) + args.append(f"{k}={code}") return f"{yield from self.function._code_repr(context)}({', '.join(args)})" diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py index e9f14925..9c0b27e8 100644 --- a/src/inline_snapshot/_customize/_custom_code.py +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -85,4 +85,4 @@ def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str return self.repr_str def __repr__(self): - return f"CustomValue({self.repr_str})" + return f"CustomCode({self.repr_str!r})" diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 4db55a15..44852437 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -86,15 +86,30 @@ def enter_snapshot_context(): _current.pm.add_hookspecs(InlineSnapshotPluginSpec) - from .plugin._default_plugin import InlineSnapshotAttrsPlugin - from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin from .plugin._default_plugin import InlineSnapshotPlugin - from .plugin._default_plugin import InlineSnapshotPydanticPlugin _current.pm.register(InlineSnapshotPlugin()) - _current.pm.register(InlineSnapshotAttrsPlugin()) - _current.pm.register(InlineSnapshotPydanticPlugin()) - _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotAttrsPlugin + except ImportError: + pass + else: + _current.pm.register(InlineSnapshotAttrsPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotPydanticPlugin + except ImportError: + pass + else: + _current.pm.register(InlineSnapshotPydanticPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin + except ImportError: + pass + else: + _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) _current.pm.load_setuptools_entrypoints(inline_snapshot_plugin_name) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index a76286c5..1415cb55 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -444,7 +444,9 @@ def intercept(change): if old_node is None: old_keywords = {key: None for key in old_value.kwargs.keys()} else: - old_keywords = {kw.arg: kw.value for kw in old_node.keywords} + old_keywords = { + kw.arg: kw.value for kw in old_node.keywords if kw.arg is not None + } for kw_arg, kw_value in old_keywords.items(): missing = kw_arg not in new_kwargs diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index 9535f31b..63f26895 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -2,6 +2,7 @@ import sys import tokenize from pathlib import Path +from types import SimpleNamespace from typing import Dict from typing import List @@ -248,11 +249,21 @@ def register_customize_hooks_from_module(self, module): self.registered_modules.add(module.__file__) + hooks = {} + for name in dir(module): obj = getattr(module, name, None) if isinstance(obj, type) and name.startswith("InlineSnapshot"): state().pm.register(obj(), name=f"") + if hasattr(obj, "inline_snapshot_impl"): + hooks[name] = obj + + if hooks: + state().pm.register( + SimpleNamespace(**hooks), name=f"" + ) + @staticmethod def test_enter(): state().missing_values = 0 diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 7f3a0806..bdbbd517 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -226,8 +226,7 @@ def outsource_handler(self, value, builder: Builder): pass except ImportError: # pragma: no cover - class InlineSnapshotDirtyEqualsPlugin: - pass + pass else: import datetime @@ -271,8 +270,7 @@ def is_now_handler(self, value, builder: Builder): import attrs except ImportError: # pragma: no cover - class InlineSnapshotAttrsPlugin: - pass + pass else: @@ -312,8 +310,7 @@ def attrs_handler(self, value, builder: Builder): import pydantic except ImportError: # pragma: no cover - class InlineSnapshotPydanticPlugin: - pass + pass else: # import pydantic diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py index c0b91b5a..e56eeaea 100644 --- a/src/inline_snapshot/plugin/_spec.py +++ b/src/inline_snapshot/plugin/_spec.py @@ -6,7 +6,7 @@ from inline_snapshot._customize._builder import Builder -inline_snapshot_plugin_name = "inline-snapshot" +inline_snapshot_plugin_name = "inline_snapshot" hookspec = pluggy.HookspecMarker(inline_snapshot_plugin_name) """ diff --git a/tests/test_customize.py b/tests/test_customize.py index 15f5302c..8b709d2c 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -243,7 +243,12 @@ def test_a(): raises=snapshot( """\ UsageError: -Customized value does not match original value: 100 != 42\ +Customized value does not match original value: + +original_value=42 + +customized_value=100 +customized_representation=CustomCode('100') """ ), ) From 0df95c8e2c8f63df237888751af50e226ff48bab Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 24 Jan 2026 21:56:20 +0100 Subject: [PATCH 65/72] feat: datetime customization --- src/inline_snapshot/_global_state.py | 6 +-- src/inline_snapshot/_new_adapter.py | 5 +-- src/inline_snapshot/plugin/_default_plugin.py | 44 ++++++++++++++++++ tests/external/storage/test_hash.py | 24 ++++++++++ tests/test_customize.py | 37 +++++++++++++++ tests/test_dirty_equals.py | 45 +++++++++++++++++++ 6 files changed, 155 insertions(+), 6 deletions(-) diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 44852437..92cb3f3a 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -93,21 +93,21 @@ def enter_snapshot_context(): try: from .plugin._default_plugin import InlineSnapshotAttrsPlugin except ImportError: - pass + pass # pragma: no cover else: _current.pm.register(InlineSnapshotAttrsPlugin()) try: from .plugin._default_plugin import InlineSnapshotPydanticPlugin except ImportError: - pass + pass # pragma: no cover else: _current.pm.register(InlineSnapshotPydanticPlugin()) try: from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin except ImportError: - pass + pass # pragma: no cover else: _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index 1415cb55..b0362ce6 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -396,9 +396,8 @@ def compare_CustomCall( @make_gen_map def intercept(change): - if flag == "update": - if change.flag == "fix": - change.flag = "update" + if flag == "update" and change.flag == "fix": + change.flag = "update" return change old_node_args: Sequence[ast.expr | None] diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index bdbbd517..0291dda8 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -75,6 +75,50 @@ def builtin_function_handler(self, value, builder: Builder): if isinstance(value, BuiltinFunctionType): return builder.create_code(value.__name__) + @customize + def datetime_handler(self, value, builder: Builder): + import datetime + + if isinstance(value, datetime.datetime): + return builder.create_call( + datetime.datetime, + [value.year, value.month, value.day], + { + "hour": builder.with_default(value.hour, 0), + "minute": builder.with_default(value.minute, 0), + "second": builder.with_default(value.second, 0), + "microsecond": builder.with_default(value.microsecond, 0), + }, + ) + + if isinstance(value, datetime.date): + return builder.create_call( + datetime.date, [value.year, value.month, value.day] + ) + + if isinstance(value, datetime.time): + return builder.create_call( + datetime.time, + [], + { + "hour": builder.with_default(value.hour, 0), + "minute": builder.with_default(value.minute, 0), + "second": builder.with_default(value.second, 0), + "microsecond": builder.with_default(value.microsecond, 0), + }, + ) + + if isinstance(value, datetime.timedelta): + return builder.create_call( + datetime.timedelta, + [], + { + "days": builder.with_default(value.days, 0), + "seconds": builder.with_default(value.seconds, 0), + "microseconds": builder.with_default(value.microseconds, 0), + }, + ) + @customize def path_handler(self, value, builder: Builder): if isinstance(value, Path): diff --git a/tests/external/storage/test_hash.py b/tests/external/storage/test_hash.py index e62bf818..2ccd28af 100644 --- a/tests/external/storage/test_hash.py +++ b/tests/external/storage/test_hash.py @@ -29,3 +29,27 @@ def test_a(): } ), ) + + +def test_same_hash(): + Example( + """\ +from inline_snapshot import external +def test_a(): + assert "a" == external("hash:") + assert "a" == external("hash:") +""", + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + ".inline-snapshot/external/ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb.txt": "a", + "tests/test_something.py": """\ +from inline_snapshot import external +def test_a(): + assert "a" == external("hash:ca978112ca1b*.txt") + assert "a" == external("hash:ca978112ca1b*.txt") +""", + } + ), + ) diff --git a/tests/test_customize.py b/tests/test_customize.py index 8b709d2c..38a2636c 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -322,3 +322,40 @@ def test_a(): } ), ) + + +def test_datetime_types(): + """Test that datetime types generate correct snapshots with proper imports.""" + + Example( + { + "test_something.py": """\ +from datetime import datetime, date, time, timedelta +from inline_snapshot import snapshot + +def test_datetime_types(): + assert snapshot() == datetime(2024, 1, 15, 10, 30, 45, 123456) + assert snapshot() == date(2024, 1, 15) + assert snapshot() == time(10, 30, 45, 123456) + assert snapshot() == timedelta(days=1, hours=2, minutes=30) + assert snapshot() == timedelta(seconds=5, microseconds=123456) +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from datetime import datetime, date, time, timedelta +from inline_snapshot import snapshot + +def test_datetime_types(): + assert snapshot(datetime(2024, 1, 15, hour=10, minute=30, second=45, microsecond=123456)) == datetime(2024, 1, 15, 10, 30, 45, 123456) + assert snapshot(date(2024, 1, 15)) == date(2024, 1, 15) + assert snapshot(time(hour=10, minute=30, second=45, microsecond=123456)) == time(10, 30, 45, 123456) + assert snapshot(timedelta(days=1, seconds=9000)) == timedelta(days=1, hours=2, minutes=30) + assert snapshot(timedelta(seconds=5, microseconds=123456)) == timedelta(seconds=5, microseconds=123456) +""" + } + ), + ).run_inline() diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py index 1098ace5..a1231a47 100644 --- a/tests/test_dirty_equals.py +++ b/tests/test_dirty_equals.py @@ -188,3 +188,48 @@ def test_time(): } ), ) + + +def test_is_now_with_delta() -> None: + """Test that IsNow handler works with custom delta parameter in customization.""" + + Example( + { + "conftest.py": """\ +from datetime import timedelta +from inline_snapshot.plugin import customize +from dirty_equals import IsNow + +@customize +def is_now_with_tolerance(value): + if value == IsNow(delta=timedelta(minutes=10)): + return IsNow(delta=timedelta(minutes=10)) +""", + "test_something.py": """\ +from datetime import datetime, timedelta +from inline_snapshot import snapshot + +def test_time(): + # Compare a time from 5 minutes ago with snapshot + past_time = datetime.now() - timedelta(minutes=5) + assert past_time == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from datetime import datetime, timedelta +from inline_snapshot import snapshot + +from dirty_equals import IsNow + +def test_time(): + # Compare a time from 5 minutes ago with snapshot + past_time = datetime.now() - timedelta(minutes=5) + assert past_time == snapshot(IsNow(delta=timedelta(seconds=600))) +""" + } + ), + ).run_inline() From a537c4cdfbf5e8e84f630d4d9b0171cb4c9d3acd Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 25 Jan 2026 11:30:54 +0100 Subject: [PATCH 66/72] refactor: coverage --- src/inline_snapshot/_change.py | 14 -------------- src/inline_snapshot/_global_state.py | 12 ++++++------ 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index bf9d4c52..176ec3dd 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -128,20 +128,6 @@ class Delete(Change): old_value: Any -@dataclass() -class AddArgument(Change): - node: ast.Call - - position: int | None - name: str | None - - new_code: str - new_value: Any - - def __post_init__(self): - self.new_code = self.file.format_expression(self.new_code) - - @dataclass() class ListInsert(Change): node: ast.List | ast.Tuple diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 92cb3f3a..360b99e0 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -92,22 +92,22 @@ def enter_snapshot_context(): try: from .plugin._default_plugin import InlineSnapshotAttrsPlugin - except ImportError: - pass # pragma: no cover + except ImportError: # pragma: no cover + pass else: _current.pm.register(InlineSnapshotAttrsPlugin()) try: from .plugin._default_plugin import InlineSnapshotPydanticPlugin - except ImportError: - pass # pragma: no cover + except ImportError: # pragma: no cover + pass else: _current.pm.register(InlineSnapshotPydanticPlugin()) try: from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin - except ImportError: - pass # pragma: no cover + except ImportError: # pragma: no cover + pass else: _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) From 8a3d36362a58fe9574480866f61394cdfb2422cf Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 26 Jan 2026 05:52:38 +0100 Subject: [PATCH 67/72] refactor: review changes and changelog --- .../20260125_114615_15r10nk-git_customize.md | 40 +++++++++++++++++++ docs/customize_repr.md | 16 ++++---- docs/plugin.md | 22 +++++----- .../_external/_find_external.py | 1 - src/inline_snapshot/_new_adapter.py | 11 +++-- src/inline_snapshot/plugin/_default_plugin.py | 16 ++++---- tests/test_docs.py | 1 - 7 files changed, 73 insertions(+), 34 deletions(-) create mode 100644 changelog.d/20260125_114615_15r10nk-git_customize.md diff --git a/changelog.d/20260125_114615_15r10nk-git_customize.md b/changelog.d/20260125_114615_15r10nk-git_customize.md new file mode 100644 index 00000000..a3f8114b --- /dev/null +++ b/changelog.d/20260125_114615_15r10nk-git_customize.md @@ -0,0 +1,40 @@ + + +### Added + +- Support for import statement generation for all types and user-customized code. +- Added a new way to customize snapshot creation with `@customize`. +- Added a plugin system which allows you to reuse customizations across multiple projects. +- Added built-in handlers for `datetime.datetime`, `date`, `time`, and `timedelta` that generate clean snapshots with proper imports. +- Added support for conditional external storage to automatically store values in external files based on custom criteria (e.g., string length, data size). +- Generates `__file__` instead of the filename string of the current source file. +- Uses dirty-equals `IsNow()` instead of the current datetime when the time value equals the current time. + + + +### Deprecated + +- Deprecated `@customize_repr` which can be replaced with `@customize`. + + + diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 17dea008..3fb3e82a 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -1,6 +1,6 @@ !!! warning "deprecated" - `@customize_repr` will be removed in the future because `@customize` provides the same and even more features. + `@customize_repr` will be removed in the future because [`@customize`](plugin.md#customize-examples) provides the same and even more features. You should use ``` python title="conftest.py" @@ -24,7 +24,7 @@ That said, what is/was `@customize_repr` for? -`repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type. +`repr()` can be used to convert a Python object into a source code representation of the object, but this does not work for every type. Here are some examples: ```pycon @@ -69,7 +69,7 @@ def _(value: MyClass): return f"{MyClass.__qualname__}({' '.join(value.values) !r})" ``` -This implementation is then used by inline-snapshot if `repr()` is called during the code generation, but not in normal code. +This implementation is then used by inline-snapshot if `repr()` is called during code generation, but not in normal code. ``` python @@ -83,17 +83,17 @@ def test_my_class(): # normal repr assert repr(e) == "['1', '5', 'hello']" - # the special implementation to convert the Enum into a code + # the special implementation to convert the Enum into code assert e == snapshot(MyClass("1 5 hello")) ``` !!! note - The example above can be better handled with `@customize` as shown in the documentation there. + The example above can be better handled with [`@customize`](plugin.md#customize-examples) as shown in the [plugin documentation](plugin.md). ## customize recursive repr -You can also use `repr()` inside `__repr__()`, if you want to make your own type compatible with inline-snapshot. +You can also use `repr()` inside `__repr__()` if you want to make your own type compatible with inline-snapshot. ``` python @@ -124,7 +124,7 @@ E = Enum("E", ["a", "b"]) def test_enum(): - # the special repr implementation is used recursive here + # the special repr implementation is used recursively here # to convert every Enum to the correct representation assert Pair(E.a, [E.b]) == snapshot(Pair(E.a, [E.b])) ``` @@ -136,7 +136,7 @@ def test_enum(): This implementation allows inline-snapshot to use the custom `repr()` recursively, but it does not allow you to use [unmanaged](/eq_snapshot.md#unmanaged-snapshot-values) snapshot values like `#!python Pair(Is(some_var),5)` -you can also customize the representation of data types in other libraries: +You can also customize the representation of data types in other libraries: ``` python from inline_snapshot import customize_repr diff --git a/docs/plugin.md b/docs/plugin.md index e9de7b67..c07aaede 100644 --- a/docs/plugin.md +++ b/docs/plugin.md @@ -1,6 +1,6 @@ -inline-snapshot provides a plugin architecture based on [pluggy](https://pluggy.readthedocs.io/en/latest/index.html) which can be used to extend and customize it. +inline-snapshot provides a plugin architecture based on [pluggy](https://pluggy.readthedocs.io/en/latest/index.html) that can be used to extend and customize it. ## Overview @@ -27,14 +27,15 @@ Plugins can: ## Plugin Discovery -inline-snapshot loads the plugins at the beginning of the session. +inline-snapshot loads plugins at the beginning of the session. It searches for plugins in: + * installed packages with an `inline-snapshot` entry point * your pytest `conftest.py` files ### Loading Plugins from conftest.py -Loading plugins from the `conftest.py` files is the recommended way when you want to change the behavior of inline-snapshot in your own project. +Loading plugins from `conftest.py` files is the recommended way when you want to change the behavior of inline-snapshot in your own project. Simply use `@customize` on functions directly in your `conftest.py`: @@ -178,7 +179,10 @@ def is_now_handler(value): return IsNow ``` -As explained in the [customize hook specification][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize], you can return types other than Custom objects. Inline-snapshot includes a built-in handler in its default plugin that converts dirty-equals expressions back into source code, which is why you can return `IsNow` directly without using the builder. This approach is much simpler than using `builder.create_call()` for complex dirty-equals expressions. +As explained in the [customize hook specification][inline_snapshot.plugin. +InlineSnapshotPluginSpec.customize], you can return types other than Custom objects. +inline-snapshot includes a built-in handler in its default plugin that converts dirty-equals expressions back into source code, which is why you can return `IsNow` directly without using the builder. +This approach is much simpler than using `builder.create_call()` for complex dirty-equals expressions. ``` python title="test_is_now.py" @@ -191,10 +195,10 @@ def test_is_now(): assert datetime.now() == snapshot(IsNow) ``` -1. Inline-snapshot also creates the imports when they are missing +1. inline-snapshot also creates the imports when they are missing !!! important - Inline-snapshot will never change the dirty-equals expressions in your code because they are [unmanaged](eq_snapshot.md#unmanaged-snapshot-values). + inline-snapshot will never change the dirty-equals expressions in your code because they are [unmanaged](eq_snapshot.md#unmanaged-snapshot-values). Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it when you change the `@customize` implementation, because it has to assume that it was created by the user. @@ -234,7 +238,7 @@ c\ ### Reusing local variables -There are times when your local or global variables become part of your snapshots, like uuids or user names. +There are times when your local or global variables become part of your snapshots, like UUIDs or user names. Customize hooks accept `local_vars` and `global_vars` as arguments that can be used to generate the code. @@ -269,7 +273,7 @@ def test_user(): assert get_data(v_user) == snapshot({"user": v_user, "age": 55}) ``` -Inline-snapshot uses `v_user` because it met the criteria in your customization hook, but not `some_number` because it does not start with `v_`. +inline-snapshot uses `v_user` because it met the criteria in your customization hook, but not `some_number` because it does not start with `v_`. You can also do this only for specific types of objects or for a whitelist of variable names. It is up to you to set the rules that work best in your project. @@ -331,7 +335,7 @@ def secret_handler(value, builder: Builder): The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the desired code representation. The `imports` parameter adds the necessary import statements. -Inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. +inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. ``` python hl_lines="4 5 9" diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index 93011a27..8c2c697a 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -116,7 +116,6 @@ def ensure_import( module_imports: Set[str], recorder: ChangeRecorder, ): - print("file", filename) source = Source.for_filename(filename) change = recorder.new_change() diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index b0362ce6..f8456bab 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -331,7 +331,6 @@ def compare_CustomDict( if isinstance(old_node, ast.Dict): node = old_node.values[list(old_value.value.keys()).index(key)] else: - assert False node = None # check values with same keys result[key] = yield from self.compare( @@ -426,8 +425,8 @@ def intercept(change): ) if old_args_len < len(new_args): - for insert_pos, value in list(enumerate(new_args))[old_args_len:]: - new_code = yield from value._code_repr(self.context) + for insert_pos, insert_value in list(enumerate(new_args))[old_args_len:]: + new_code = yield from insert_value._code_repr(self.context) yield CallArg( flag=flag, file=self.context.file, @@ -435,7 +434,7 @@ def intercept(change): arg_pos=insert_pos, arg_name=None, new_code=new_code, - new_value=value, + new_value=insert_value, ) # keyword arguments @@ -482,14 +481,14 @@ def intercept(change): ) if to_insert: - for key, value in to_insert: + for insert_key, value in to_insert: new_code = yield from value._code_repr(self.context) yield CallArg( flag=flag, file=self.context.file, node=old_node, arg_pos=insert_pos, - arg_name=key, + arg_name=insert_key, new_code=new_code, new_value=value, ) diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py index 0291dda8..9c11bc7d 100644 --- a/src/inline_snapshot/plugin/_default_plugin.py +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -1,4 +1,5 @@ import ast +import datetime from collections import Counter from collections import defaultdict from dataclasses import MISSING @@ -77,7 +78,6 @@ def builtin_function_handler(self, value, builder: Builder): @customize def datetime_handler(self, value, builder: Builder): - import datetime if isinstance(value, datetime.datetime): return builder.create_call( @@ -133,10 +133,12 @@ def sort_set_values(self, set_values): set_values = sorted(set_values) is_sorted = True except TypeError: + # can not be sorted by value pass set_values = list(map(repr, set_values)) if not is_sorted: + # sort by string representation set_values = sorted(set_values) return set_values @@ -267,15 +269,11 @@ def outsource_handler(self, value, builder: Builder): try: - pass + import dirty_equals except ImportError: # pragma: no cover - pass - else: - import datetime - from dirty_equals import IsNow from dirty_equals._utils import Omit class InlineSnapshotDirtyEqualsPlugin: @@ -295,7 +293,7 @@ def dirty_equals_handler(self, value, builder: Builder): kwargs = { k: a for k, a in value._repr_kwargs.items() if a is not Omit } - if type(value) == IsNow: + if type(value) == dirty_equals.IsNow: kwargs.pop("approx") if ( isinstance(delta := kwargs["delta"], datetime.timedelta) @@ -306,8 +304,8 @@ def dirty_equals_handler(self, value, builder: Builder): @customize(tryfirst=True) def is_now_handler(self, value, builder: Builder): - if value == IsNow(): - return IsNow() + if value == dirty_equals.IsNow(): + return dirty_equals.IsNow() try: diff --git a/tests/test_docs.py b/tests/test_docs.py index 1a4a9703..9c85ae2f 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -447,7 +447,6 @@ def test_block(block: Block): ) linenum = 1 - hl_lines = "" if last_code is not None and "first_block" not in options: changed_lines = [] From 93dd721ab1af7536e77eeed34fc0e16121c32249 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 30 Jan 2026 08:31:16 +0100 Subject: [PATCH 68/72] fix: fixes for pydantic-ai --- src/inline_snapshot/_customize/_builder.py | 14 ++++++- src/inline_snapshot/_utils.py | 6 +++ tests/test_pydantic.py | 46 ++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index d58c07fc..ea75ecb6 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -51,7 +51,19 @@ def _get_handler(self, v) -> Custom: result.__dict__["original_value"] = v if not isinstance(v, Custom) and self._build_new_value: - if result._eval() != v: + is_same = False + v_eval = result._eval() + + if ( + hasattr(v, "__pydantic_generic_metadata__") + and v.__pydantic_generic_metadata__["origin"] == v_eval + ): + is_same = True + + if not is_same and v_eval == v: + is_same = True + + if not is_same: raise UsageError( f"""\ Customized value does not match original value: diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 0ba40271..d1403c42 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -3,6 +3,9 @@ import token from collections import namedtuple from pathlib import Path +from types import BuiltinFunctionType +from types import FunctionType +from types import MethodType from inline_snapshot._exceptions import UsageError @@ -149,6 +152,9 @@ def __eq__(self, other): def clone(obj): + if isinstance(obj, (type, FunctionType, BuiltinFunctionType, MethodType)): + return obj + new = copy.deepcopy(obj) if not obj == new: raise UsageError( diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index fdb9f016..de11ca65 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -1,3 +1,6 @@ +import pydantic +import pytest + from inline_snapshot import snapshot from inline_snapshot.testing import Example @@ -181,3 +184,46 @@ def test_something(): ), returncode=1, ) + + +@pytest.mark.skipif( + pydantic.version.VERSION.startswith("1."), + reason="pydantic 1 can not compare C[int]() with C()", +) +def test_pydantic_generic_class(): + Example( + """\ +from typing import Generic, TypeVar +from inline_snapshot import snapshot +from pydantic import BaseModel + +I=TypeVar("I") +class C(BaseModel,Generic[I]): + a:int + +def test_a(): + c=C[int](a=5) + + assert c == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from typing import Generic, TypeVar +from inline_snapshot import snapshot +from pydantic import BaseModel + +I=TypeVar("I") +class C(BaseModel,Generic[I]): + a:int + +def test_a(): + c=C[int](a=5) + + assert c == snapshot(C(a=5)) +""" + } + ), + ).run_inline() From d93a47070d83c9bb8b1ec8bf185bb889659976b4 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 3 Feb 2026 09:03:55 +0100 Subject: [PATCH 69/72] fix: coverage --- src/inline_snapshot/_external/_external_file.py | 4 +++- src/inline_snapshot/_new_adapter.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/inline_snapshot/_external/_external_file.py b/src/inline_snapshot/_external/_external_file.py index e2d64d75..978e6d2f 100644 --- a/src/inline_snapshot/_external/_external_file.py +++ b/src/inline_snapshot/_external/_external_file.py @@ -39,7 +39,9 @@ def _load_value(self): try: return self._format.decode(self._filename) except FileNotFoundError: - raise StorageLookupError(f"can not read {self._filename}") + raise StorageLookupError( + f"can not read {self._filename}", files=[self._filename] + ) def external_file(path: Union[Path, str], *, format: Optional[str] = None): diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py index f8456bab..e73cb871 100644 --- a/src/inline_snapshot/_new_adapter.py +++ b/src/inline_snapshot/_new_adapter.py @@ -158,7 +158,7 @@ def compare( else True ) and ( - isinstance(old_value, (CustomCall, CustomSequence)) + isinstance(old_value, (CustomCall, CustomSequence, CustomDict)) if old_node is None else True ) From 0ded4c7812c986b0c5688aab2bd23933d1de615c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sun, 8 Feb 2026 06:58:59 +0100 Subject: [PATCH 70/72] fix: fix corner case for 3.10 --- src/inline_snapshot/_customize/_builder.py | 24 ++++++--- .../_snapshot/generic_value.py | 4 +- .../_snapshot/undecided_value.py | 54 +++++++++++++++++++ tests/test_without_node.py | 30 +++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 tests/test_without_node.py diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index ea75ecb6..991be12f 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -28,6 +28,13 @@ class Builder: _snapshot_context: AdapterContext _build_new_value: bool = False + _recursive: bool = True + + def _get_handler_recursive(self, v) -> Custom: + if self._recursive: + return self._get_handler(v) + else: + return v def _get_handler(self, v) -> Custom: @@ -93,7 +100,7 @@ def create_list(self, value: list) -> Custom: `create_list([1,2,3])` becomes `[1,2,3]` in the code. List elements don't have to be Custom nodes and are converted by inline-snapshot if needed. """ - custom = [self._get_handler(v) for v in value] + custom = [self._get_handler_recursive(v) for v in value] return CustomList(value=custom) def create_tuple(self, value: tuple) -> Custom: @@ -103,7 +110,7 @@ def create_tuple(self, value: tuple) -> Custom: `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. Tuple elements don't have to be Custom nodes and are converted by inline-snapshot if needed. """ - custom = [self._get_handler(v) for v in value] + custom = [self._get_handler_recursive(v) for v in value] return CustomTuple(value=custom) def with_default(self, value: Any, default: Any): @@ -117,7 +124,7 @@ def with_default(self, value: Any, default: Any): raise UsageError("default value can not be an Custom value") if value == default: - return CustomDefault(value=self._get_handler(value)) + return CustomDefault(value=self._get_handler_recursive(value)) return value def create_call( @@ -129,9 +136,9 @@ def create_call( `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. Function, arguments, and keyword arguments don't have to be Custom nodes and are converted by inline-snapshot if needed. """ - function = self._get_handler(function) - posonly_args = [self._get_handler(arg) for arg in posonly_args] - kwargs = {k: self._get_handler(arg) for k, arg in kwargs.items()} + function = self._get_handler_recursive(function) + posonly_args = [self._get_handler_recursive(arg) for arg in posonly_args] + kwargs = {k: self._get_handler_recursive(arg) for k, arg in kwargs.items()} return CustomCall( function=function, @@ -146,7 +153,10 @@ def create_dict(self, value: dict) -> Custom: `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. Dict keys and values don't have to be Custom nodes and are converted by inline-snapshot if needed. """ - custom = {self._get_handler(k): self._get_handler(v) for k, v in value.items()} + custom = { + self._get_handler_recursive(k): self._get_handler_recursive(v) + for k, v in value.items() + } return CustomDict(value=custom) @cached_property diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 01f17ba5..39a157a5 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -58,7 +58,9 @@ def value_to_custom(self, value): return value if self._ast_node is None: - return self.to_custom(value) + from inline_snapshot._snapshot.undecided_value import ValueToCustom + + return ValueToCustom(self._context).convert(value) else: from inline_snapshot._snapshot.undecided_value import AstToCustom diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index b9bd1055..d7dcaaff 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -2,9 +2,12 @@ from typing import Any from typing import Iterator +from inline_snapshot._code_repr import mock_repr from inline_snapshot._compare_context import compare_only +from inline_snapshot._customize._builder import Builder from inline_snapshot._customize._custom import Custom from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_call import CustomDefault from inline_snapshot._customize._custom_code import CustomCode from inline_snapshot._customize._custom_dict import CustomDict from inline_snapshot._customize._custom_sequence import CustomList @@ -21,6 +24,7 @@ class AstToCustom: + context: AdapterContext def __init__(self, context): self.eval = context.eval @@ -69,6 +73,56 @@ def convert_Dict(self, value: dict, node: ast.Dict): ) +class ValueToCustom: + """this implementation is for cpython <= 3.10 only. + It works similar to AstToCustom but can not handle calls + """ + + context: AdapterContext + + def __init__(self, context): + self.context = context + + def convert(self, value: Any): + if is_unmanaged(value): + return CustomUnmanaged(value) + + if isinstance(value, CustomDefault): + return self.convert(value.value) + + t = type(value).__name__ + return getattr(self, "convert_" + t, self.convert_generic)(value) + + def convert_generic(self, value: Any) -> Custom: + if value is ...: + return CustomUndefined() + else: + with mock_repr(self.context): + result = Builder(self.context, _recursive=False)._get_handler(value) + if isinstance(result, CustomCall) and result.function == type(value): + function = self.convert(result.function) + posonly_args = [self.convert(arg) for arg in result.args] + kwargs = {k: self.convert(arg) for k, arg in result.kwargs.items()} + + return CustomCall( + function=function, + args=posonly_args, + kwargs=kwargs, + ) + return CustomCode(value, "") + + def convert_list(self, value: list): + return CustomList([self.convert(v) for v in value]) + + def convert_tuple(self, value: tuple): + return CustomTuple([self.convert(v) for v in value]) + + def convert_dict(self, value: dict): + return CustomDict( + {self.convert(k): self.convert(v) for (k, v) in value.items()} + ) + + class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): self._context = context diff --git a/tests/test_without_node.py b/tests/test_without_node.py new file mode 100644 index 00000000..24e08192 --- /dev/null +++ b/tests/test_without_node.py @@ -0,0 +1,30 @@ +import pytest +from executing import is_pytest_compatible + +from inline_snapshot.testing import Example + + +@pytest.mark.skipIf( + is_pytest_compatible, reason="this is only a problem when executing can return None" +) +def test_without_node(): + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize + +@customize +def handler(value,builder): + if value=="foo": + return builder.create_code("'foo'") +""", + "test_example.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsStr + +def test_foo(): + assert "not_foo" == snapshot(IsStr()) +""", + } + ).run_pytest() From c3953c90bcbee699d148005c7190e70138cbc3b0 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Mon, 9 Feb 2026 09:07:51 +0100 Subject: [PATCH 71/72] fix: original_value in custom child --- src/inline_snapshot/_customize/_builder.py | 2 +- tests/test_customize.py | 45 ++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py index 991be12f..05dd9eaa 100644 --- a/src/inline_snapshot/_customize/_builder.py +++ b/src/inline_snapshot/_customize/_builder.py @@ -55,7 +55,7 @@ def _get_handler(self, v) -> Custom: else: result = r - result.__dict__["original_value"] = v + result.__dict__["original_value"] = v._eval() if isinstance(v, Custom) else v if not isinstance(v, Custom) and self._build_new_value: is_same = False diff --git a/tests/test_customize.py b/tests/test_customize.py index 38a2636c..a27b2f29 100644 --- a/tests/test_customize.py +++ b/tests/test_customize.py @@ -359,3 +359,48 @@ def test_datetime_types(): } ), ).run_inline() + + +def test_custom_children(): + Example( + { + "c.py": """\ +from dataclasses import dataclass + +@dataclass +class C: + i:int +""", + "conftest.py": """\ +from inline_snapshot.plugin import customize +from c import C + +@customize +def handler(value,builder): + if isinstance(value,C) and value.i==2: + return builder.create_call(C,[],{"i":builder.create_code("1+1")}) + """, + "test_example.py": """\ +from inline_snapshot import snapshot +from c import C + +def test(): + assert C(i=2) == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_example.py": """\ +from inline_snapshot import snapshot +from c import C + +def test(): + assert C(i=2) == snapshot(C(i=1 + 1)) +""" + } + ), + ).run_inline( + ["--inline-snapshot=fix"] + ) From 6aeb712ba266cd30a610fa375ef53b843768438c Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Tue, 10 Feb 2026 19:26:15 +0100 Subject: [PATCH 72/72] fix: coverage --- tests/test_without_node.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_without_node.py b/tests/test_without_node.py index 24e08192..5f13d9c5 100644 --- a/tests/test_without_node.py +++ b/tests/test_without_node.py @@ -1,6 +1,7 @@ import pytest from executing import is_pytest_compatible +from inline_snapshot import snapshot from inline_snapshot.testing import Example @@ -28,3 +29,46 @@ def test_foo(): """, } ).run_pytest() + + +def test_custom_default_case_in_ValueToCustom(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=5 + +def test_something(): + assert A(a=3) == snapshot(A(a=5)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_tuple_case_in_ValueToCustom(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=5 + +def test_something(): + assert (1,2) == snapshot((1,2)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot(None), + )