diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c794ad2..caa69a99 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,7 +122,8 @@ jobs: uvx coverage html --skip-covered --skip-empty # Report and write to summary. - uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY + uvx coverage report --format=markdown --skip-covered --skip-empty -m >> $GITHUB_STEP_SUMMARY + uvx coverage report --skip-covered --skip-empty -m > htmlcov/report.txt # Report again and fail if under 100%. uvx coverage report --fail-under=100 diff --git a/README.md b/README.md index 817c0b86..5fdabde2 100644 --- a/README.md +++ b/README.md @@ -85,9 +85,9 @@ def test_something(): The following examples show how you can use inline-snapshot in your tests. Take a look at the [documentation](https://15r10nk.github.io/inline-snapshot/latest) if you want to know more. - + ``` python -from inline_snapshot import snapshot, outsource, external +from inline_snapshot import external, outsource, snapshot def test_something(): @@ -135,9 +135,9 @@ strings\ ``` python -from inline_snapshot import snapshot import subprocess as sp import sys +from inline_snapshot import snapshot def run_python(cmd, stdout=None, stderr=None): diff --git a/changelog.d/20251201_200505_15r10nk-git_customize.md b/changelog.d/20251201_200505_15r10nk-git_customize.md new file mode 100644 index 00000000..b6cc6f8a --- /dev/null +++ b/changelog.d/20251201_200505_15r10nk-git_customize.md @@ -0,0 +1,3 @@ +### Added + +- `pathlib.Path/PurePath` values are now never stored as `Posix/WindowsPath` or their Pure variants, which improves the writing of platform independent tests. diff --git a/changelog.d/20260125_114615_15r10nk-git_customize.md b/changelog.d/20260125_114615_15r10nk-git_customize.md new file mode 100644 index 00000000..a3f8114b --- /dev/null +++ b/changelog.d/20260125_114615_15r10nk-git_customize.md @@ -0,0 +1,40 @@ + + +### Added + +- Support for import statement generation for all types and user-customized code. +- Added a new way to customize snapshot creation with `@customize`. +- Added a plugin system which allows you to reuse customizations across multiple projects. +- Added built-in handlers for `datetime.datetime`, `date`, `time`, and `timedelta` that generate clean snapshots with proper imports. +- Added support for conditional external storage to automatically store values in external files based on custom criteria (e.g., string length, data size). +- Generates `__file__` instead of the filename string of the current source file. +- Uses dirty-equals `IsNow()` instead of the current datetime when the time value equals the current time. + + + +### Deprecated + +- Deprecated `@customize_repr` which can be replaced with `@customize`. + + + diff --git a/docs/categories.md b/docs/categories.md index 55c31e48..157a5ba2 100644 --- a/docs/categories.md +++ b/docs/categories.md @@ -184,8 +184,10 @@ It is recommended to use trim only if you run your complete test suite. ### Update Changes in the update category do not change the value of the snapshot, just its representation. -These updates are not shown by default in your reports and can be enabled with [show-updates](configuration.md/#show-updates). -The reason might be that `#!python repr()` of the object has changed or that inline-snapshot provides some new logic which changes the representation. Like with the strings in the following example: +These updates are not shown by default in your reports, because it can be confusing for users who uses inline-snapshot the first time or want to change the snapshot values manual. +Updates can be enabled with [show-updates](configuration.md/#show-updates). + +The reason for updates might be that `#!python repr()` of the object has changed or that inline-snapshot provides some new logic which changes the representation. Like with the strings in the following example: === "original" @@ -254,7 +256,7 @@ The reason might be that `#!python repr()` of the object has changed or that inl ``` -The approval of this type of changes is easier, because inline-snapshot assures that the value has not changed. +The approval of this type of changes is easier, because the update category assures that the value has not changed. The goal of inline-snapshot is to generate the values for you in the correct format so that no manual editing is required. This improves your productivity and saves time. @@ -270,5 +272,5 @@ You can agree with inline-snapshot and accept the changes or you can use one of 4. you can also open an [issue](https://github.com/15r10nk/inline-snapshot/issues?q=is%3Aissue%20state%3Aopen%20label%3Aupdate_related) if you have a specific problem with the way inline-snapshot generates the code. -!!! note: - [#177](https://github.com/15r10nk/inline-snapshot/issues/177) will give the developer more control about how snapshots are created. *update* will them become much more useful. +!!! note + [#177](https://github.com/15r10nk/inline-snapshot/issues/177) will give the developer more control about how snapshots are created. *update* will then become much more useful. diff --git a/docs/code_generation.md b/docs/code_generation.md index 3556780b..9a472640 100644 --- a/docs/code_generation.md +++ b/docs/code_generation.md @@ -7,8 +7,8 @@ It might be necessary to import the right modules to match the `repr()` output. === "original code" ``` python - from inline_snapshot import snapshot import datetime + from inline_snapshot import snapshot def something(): @@ -29,8 +29,8 @@ It might be necessary to import the right modules to match the `repr()` output. === "--inline-snapshot=create" ``` python hl_lines="18 19 20 21 22 23 24 25 26 27 28" - from inline_snapshot import snapshot import datetime + from inline_snapshot import snapshot def something(): diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 5b01c13f..3fb3e82a 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -1,7 +1,30 @@ +!!! warning "deprecated" + `@customize_repr` will be removed in the future because [`@customize`](plugin.md#customize-examples) provides the same and even more features. + You should use + ``` python title="conftest.py" + class InlineSnapshotPlugin: + @customize + def my_class_handler(value, builder): + if isinstance(value, MyClass): + return builder.create_code("my_class_repr") + ``` -`repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type. + instead of + + ``` python title="conftest.py" + @customize_repr + def my_class_handler(value: MyClass): + return "my_class_repr" + ``` + + `@customize` allows you not only to generate code but also imports and function calls which can be analysed by inline-snapshot. + + +That said, what is/was `@customize_repr` for? + +`repr()` can be used to convert a Python object into a source code representation of the object, but this does not work for every type. Here are some examples: ```pycon @@ -16,64 +39,66 @@ Here are some examples: `customize_repr` can be used to overwrite the default `repr()` behaviour. -The implementation for `Enum` looks like this: +The implementation for `MyClass` could look like this: -``` python exec="1" result="python" -print('--8<-- "src/inline_snapshot/_code_repr.py:Enum"') -``` + +``` python title="my_class.py" +class MyClass: + def __init__(self, values): + self.values = values.split() -This implementation is then used by inline-snapshot if `repr()` is called during the code generation, but not in normal code. + def __repr__(self): + return repr(self.values) - -``` python -from inline_snapshot import snapshot -from enum import Enum + def __eq__(self, other): + if not isinstance(other, MyClass): + return NotImplemented + return self.values == other.values +``` +You can specify the `repr()` used by inline-snapshot in your *conftest.py* -def test_enum(): - E = Enum("E", ["a", "b"]) + +``` python title="conftest.py" +from my_class import MyClass +from inline_snapshot import customize_repr - # normal repr - assert repr(E.a) == "" - # the special implementation to convert the Enum into a code - assert E.a == snapshot(E.a) +@customize_repr +def _(value: MyClass): + return f"{MyClass.__qualname__}({' '.join(value.values) !r})" ``` -## built-in data types +This implementation is then used by inline-snapshot if `repr()` is called during code generation, but not in normal code. -inline-snapshot comes with a special implementation for the following types: + +``` python +from my_class import MyClass +from inline_snapshot import snapshot -``` python exec="1" -from inline_snapshot._code_repr import code_repr_dispatch, code_repr -for name, obj in sorted( - ( - getattr( - obj, "_inline_snapshot_name", f"{obj.__module__}.{obj.__qualname__}" - ), - obj, - ) - for obj in code_repr_dispatch.registry.keys() -): - if obj is not object: - print(f"- `{name}`") -``` +def test_my_class(): + e = MyClass("1 5 hello") -Please open an [issue](https://github.com/15r10nk/inline-snapshot/issues) if you found a built-in type which is not supported by inline-snapshot. + # normal repr + assert repr(e) == "['1', '5', 'hello']" + + # the special implementation to convert the Enum into code + assert e == snapshot(MyClass("1 5 hello")) +``` !!! note - Container types like `dict`, `list`, `tuple` or `dataclass` are handled in a different way, because inline-snapshot also needs to inspect these types to implement [unmanaged](/eq_snapshot.md#unmanaged-snapshot-values) snapshot values. + The example above can be better handled with [`@customize`](plugin.md#customize-examples) as shown in the [plugin documentation](plugin.md). ## customize recursive repr -You can also use `repr()` inside `__repr__()`, if you want to make your own type compatible with inline-snapshot. +You can also use `repr()` inside `__repr__()` if you want to make your own type compatible with inline-snapshot. ``` python -from inline_snapshot import snapshot from enum import Enum +from inline_snapshot import snapshot class Pair: @@ -94,10 +119,12 @@ class Pair: return self.a == other.a and self.b == other.b +E = Enum("E", ["a", "b"]) + + def test_enum(): - E = Enum("E", ["a", "b"]) - # the special repr implementation is used recursive here + # the special repr implementation is used recursively here # to convert every Enum to the correct representation assert Pair(E.a, [E.b]) == snapshot(Pair(E.a, [E.b])) ``` @@ -109,7 +136,7 @@ def test_enum(): This implementation allows inline-snapshot to use the custom `repr()` recursively, but it does not allow you to use [unmanaged](/eq_snapshot.md#unmanaged-snapshot-values) snapshot values like `#!python Pair(Is(some_var),5)` -you can also customize the representation of data types in other libraries: +You can also customize the representation of data types in other libraries: ``` python from inline_snapshot import customize_repr diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index d5d3d4ad..fd3e1df7 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -117,8 +117,8 @@ You could initialize a test like this: ``` python -from inline_snapshot import snapshot import datetime +from inline_snapshot import snapshot def get_data(): @@ -135,31 +135,11 @@ def test_function(): If you use `--inline-snapshot=create`, inline-snapshot will record the current `datetime` in the snapshot: -``` python hl_lines="13 14 15" -from inline_snapshot import snapshot +``` python hl_lines="4 5 15" import datetime - - -def get_data(): - return { - "date": datetime.datetime.utcnow(), - "payload": "some data", - } - - -def test_function(): - assert get_data() == snapshot( - {"date": datetime.datetime(2024, 3, 14, 0, 0), "payload": "some data"} - ) -``` - -To avoid the test failing in future runs, replace the `datetime` with [dirty-equals' `IsDatetime()`](https://dirty-equals.helpmanual.io/latest/types/datetime/#dirty_equals.IsDatetime): - - -``` python from inline_snapshot import snapshot -from dirty_equals import IsDatetime -import datetime + +from dirty_equals import IsNow def get_data(): @@ -170,21 +150,17 @@ def get_data(): def test_function(): - assert get_data() == snapshot( - { - "date": IsDatetime(), - "payload": "some data", - } - ) + assert get_data() == snapshot({"date": IsNow(), "payload": "some data"}) ``` +Inline-snapshot uses [dirty-equals `IsNow()`](https://dirty-equals.helpmanual.io/latest/types/datetime/#dirty_equals.IsDatetime) by default when the value is equal to the current time to avoid the test failing in future runs. Say a different part of the return data changes, such as the `payload` value: -``` python hl_lines="9" -from inline_snapshot import snapshot -from dirty_equals import IsDatetime +``` python hl_lines="3 9 14 15 16 17 18 19" import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot def get_data(): @@ -197,7 +173,7 @@ def get_data(): def test_function(): assert get_data() == snapshot( { - "date": IsDatetime(), + "date": IsNow(), "payload": "some data", } ) @@ -207,9 +183,9 @@ Re-running the test with `--inline-snapshot=fix` will update the snapshot to mat ``` python hl_lines="17" -from inline_snapshot import snapshot -from dirty_equals import IsDatetime import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot def get_data(): @@ -222,7 +198,7 @@ def get_data(): def test_function(): assert get_data() == snapshot( { - "date": IsDatetime(), + "date": IsNow(), "payload": "data changed for some good reason", } ) @@ -270,7 +246,7 @@ It tells inline-snapshot that the developer wants control over some part of the ``` python -from inline_snapshot import snapshot, Is +from inline_snapshot import Is, snapshot current_version = "1.5" @@ -292,7 +268,7 @@ The snapshot does not need to be fixed when `current_version` changes in the fut === "original code" ``` python - from inline_snapshot import snapshot, Is + from inline_snapshot import Is, snapshot def test_function(): @@ -303,7 +279,7 @@ The snapshot does not need to be fixed when `current_version` changes in the fut === "--inline-snapshot=fix" ``` python hl_lines="6" - from inline_snapshot import snapshot, Is + from inline_snapshot import Is, snapshot def test_function(): @@ -326,7 +302,7 @@ The following example shows how this can be used to run a tests with two differe === "my_lib.py v1" - ``` python + ``` python title="my_lib.py" version = 1 @@ -337,7 +313,7 @@ The following example shows how this can be used to run a tests with two differe === "my_lib.py v2" - ``` python + ``` python title="my_lib.py" version = 2 @@ -348,8 +324,8 @@ The following example shows how this can be used to run a tests with two differe ``` python +from my_lib import get_schema, version from inline_snapshot import snapshot -from my_lib import version, get_schema def test_function(): @@ -368,8 +344,8 @@ The advantage of this approach is that the test uses always the correct values f You can also extract the version logic into its own function. ``` python -from inline_snapshot import snapshot, Snapshot -from my_lib import version, get_schema +from my_lib import get_schema, version +from inline_snapshot import Snapshot, snapshot def version_snapshot(v1: Snapshot, v2: Snapshot): diff --git a/docs/external/external.md b/docs/external/external.md index bfecbe45..b7ac463f 100644 --- a/docs/external/external.md +++ b/docs/external/external.md @@ -64,7 +64,7 @@ The `external()` function can also be used inside other data structures. ``` python -from inline_snapshot import snapshot, external +from inline_snapshot import external, snapshot def test_something(): @@ -75,7 +75,7 @@ def test_something(): ``` python hl_lines="6 7 8 9 10 11 12 13" -from inline_snapshot import snapshot, external +from inline_snapshot import external, snapshot def test_something(): @@ -176,7 +176,7 @@ You must specify the suffix in this case. ``` python -from inline_snapshot import register_format_alias, external +from inline_snapshot import external, register_format_alias register_format_alias(".html", ".txt") @@ -189,7 +189,7 @@ inline-snapshot uses the given suffix to create an external snapshot. ``` python hl_lines="7 8 9" -from inline_snapshot import register_format_alias, external +from inline_snapshot import external, register_format_alias register_format_alias(".html", ".txt") diff --git a/docs/external/outsource.md b/docs/external/outsource.md index cde6bb58..4c50323f 100644 --- a/docs/external/outsource.md +++ b/docs/external/outsource.md @@ -26,7 +26,7 @@ def test_captcha(): inline-snapshot always generates an external object in this case. -``` python hl_lines="3 4 20 21 22 23 24 25 26" +``` python hl_lines="3 4 20 21 22 23 24 25 26 27 28" from inline_snapshot import outsource, register_format_alias, snapshot from inline_snapshot import external @@ -50,7 +50,9 @@ def test_captcha(): { "size": "200x100", "difficulty": 8, - "picture": external("hash:0da2cc316111*.png"), + "picture": external( + "uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.png" + ), } ) ``` diff --git a/docs/external/register_format.md b/docs/external/register_format.md index 216ec030..c0e72cbb 100644 --- a/docs/external/register_format.md +++ b/docs/external/register_format.md @@ -129,8 +129,8 @@ The custom format is then used every time a `NumberSet` is compared with an empt === "example" ``` python - from inline_snapshot import external from number_set import NumberSet + from inline_snapshot import external def test(): diff --git a/docs/plugin.md b/docs/plugin.md new file mode 100644 index 00000000..c07aaede --- /dev/null +++ b/docs/plugin.md @@ -0,0 +1,377 @@ + + +inline-snapshot provides a plugin architecture based on [pluggy](https://pluggy.readthedocs.io/en/latest/index.html) that can be used to extend and customize it. + +## Overview + +Plugins allow you to customize how inline-snapshot generates code for your snapshots. The primary use case is implementing custom representation logic through the `@customize` hook, which controls how Python objects are converted into source code. + +### When to Use Plugins + +You should consider creating a plugin when: + +- You find yourself manually editing snapshots after they are generated +- You want to use custom constructors or factory methods in your snapshots +- You need to reference local/global variables instead of hardcoding values +- You want to store certain values in external files based on specific criteria +- You need special code representations for your custom types + +### Plugin Capabilities + +Plugins can: + +- **Customize code generation**: Control how objects appear in snapshot code (e.g., use `Color.RED` instead of `Color(255, 0, 0)`) +- **Reference variables**: Use existing local or global variables in snapshots instead of literals +- **External storage**: Automatically store large or sensitive values in external files +- **Import management**: Automatically add necessary import statements to test files + +## Plugin Discovery + +inline-snapshot loads plugins at the beginning of the session. +It searches for plugins in: + +* installed packages with an `inline-snapshot` entry point +* your pytest `conftest.py` files + +### Loading Plugins from conftest.py + +Loading plugins from `conftest.py` files is the recommended way when you want to change the behavior of inline-snapshot in your own project. + +Simply use `@customize` on functions directly in your `conftest.py`: + +``` python +from inline_snapshot.plugin import customize + + +@customize +def my_handler(value, builder): + # your logic + pass +``` + +All customizations defined in your `conftest.py` are active globally for all your tests. + +### Creating a Plugin Package + +To distribute inline-snapshot plugins as a package, register your plugin class using the `inline-snapshot` entry point in your `setup.py` or `pyproject.toml`: + +=== "pyproject.toml (recommended)" + ``` toml + [project.entry-points.inline_snapshot] + my_plugin = "my_package.plugin:MyInlineSnapshotPlugin" + ``` + +=== "setup.py" + ``` python + setup( + name="my-inline-snapshot-plugin", + entry_points={ + "inline_snapshot": [ + "my_plugin = my_package.plugin", + ], + }, + ) + ``` + +Your plugin class should contain methods decorated with `@customize`, just like in conftest.py: + +``` python title="my_package/plugin.py" +from inline_snapshot.plugin import customize, Builder + + +@customize +def my_custom_handler(value, builder: Builder): + # Your customization logic here + if isinstance(value, YourCustomType): + return builder.create_call(YourCustomType, [value.arg]) +``` + +Once installed, the plugin will be automatically loaded by inline-snapshot. + +### Plugin Specification + +::: inline_snapshot.plugin + options: + heading_level: 3 + members: [InlineSnapshotPluginSpec] + show_root_heading: false + show_bases: false + show_source: false + + + +## Customize Examples + +The following examples demonstrate common use cases for the `@customize` hook. Each example shows how to implement custom representation logic for different scenarios. + +The [customize][inline_snapshot.plugin.InlineSnapshotPluginSpec.customize] hook controls how inline-snapshot generates your snapshots. +You should use it when you find yourself manually editing snapshots after they were created by inline-snapshot. + + +### Custom constructor methods +One use case might be that you have a dataclass with a special constructor function that can be used for specific instances of this dataclass, and you want inline-snapshot to use this constructor when possible. + + +``` python title="rect.py" +from dataclasses import dataclass + + +@dataclass +class Rect: + width: int + height: int + + @staticmethod + def make_square(size): + return Rect(size, size) +``` + +You can define a hook in your `conftest.py` that checks if your value is a square and calls the correct constructor function. + + +``` python title="conftest.py" +from rect import Rect +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + + +@customize +def square_handler(value, builder: Builder): + if isinstance(value, Rect) and value.width == value.height: + return builder.create_call(Rect.make_square, [value.width]) +``` + +This allows you to influence the code that is created by inline-snapshot. + + +``` python title="test_square.py" +from rect import Rect +from inline_snapshot import snapshot + + +def test_square(): + assert Rect.make_square(5) == snapshot(Rect.make_square(5)) # (1)! + assert Rect(1, 1) == snapshot(Rect.make_square(1)) # (2)! + assert Rect(1, 2) == snapshot(Rect(width=1, height=2)) # (3)! + assert [Rect(3, 3), Rect(4, 5)] == snapshot( + [Rect.make_square(3), Rect(width=4, height=5)] + ) # (4)! +``` + +1. Your handler is used because you created a square +2. Your handler is used because you created a Rect that happens to have the same width and height +3. Your handler is not used because width and height are different +4. The handler is applied recursively to each Rect inside the list - the first is converted to `make_square()` while the second uses the regular constructor + +### dirty-equal expressions +It can also be used to instruct inline-snapshot to use specific dirty-equals expressions for specific values. + + +``` python title="conftest.py" +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder +from dirty_equals import IsNow + + +@customize +def is_now_handler(value): + if value == IsNow(): + return IsNow +``` + +As explained in the [customize hook specification][inline_snapshot.plugin. +InlineSnapshotPluginSpec.customize], you can return types other than Custom objects. +inline-snapshot includes a built-in handler in its default plugin that converts dirty-equals expressions back into source code, which is why you can return `IsNow` directly without using the builder. +This approach is much simpler than using `builder.create_call()` for complex dirty-equals expressions. + + +``` python title="test_is_now.py" +from datetime import datetime +from dirty_equals import IsNow # (1)! +from inline_snapshot import snapshot + + +def test_is_now(): + assert datetime.now() == snapshot(IsNow) +``` + +1. inline-snapshot also creates the imports when they are missing + +!!! important + inline-snapshot will never change the dirty-equals expressions in your code because they are [unmanaged](eq_snapshot.md#unmanaged-snapshot-values). + Using `@customize` with dirty-equals is a one-way ticket. Once the code is created, inline-snapshot does not know if it was created by inline-snapshot itself or by the user and will not change it when you change the `@customize` implementation, because it has to assume that it was created by the user. + + +### Conditional external objects + +`create_external` can be used to store values in external files if a specific criterion is met. + + +``` python title="conftest.py" +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + + +@customize +def long_string_handler(value, builder: Builder): + if isinstance(value, str) and value.count("\n") > 5: + return builder.create_external(value) +``` + + +``` python title="test_long_strings.py" +from inline_snapshot import external, snapshot + + +def test_long_strings(): + assert "a\nb\nc" == snapshot( + """\ +a +b +c\ +""" + ) + assert "a\n" * 50 == snapshot( + external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt") + ) +``` + +### Reusing local variables + +There are times when your local or global variables become part of your snapshots, like UUIDs or user names. +Customize hooks accept `local_vars` and `global_vars` as arguments that can be used to generate the code. + + +``` python title="conftest.py" +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + + +@customize +def local_var_handler(value, builder, local_vars): + for var_name, var_value in local_vars.items(): + if var_name.startswith("v_") and var_value == value: + return builder.create_code(var_name) +``` + +We check all local variables to see if they match our naming convention and are equal to the value that is part of our snapshot, and return the local variable if we find one that fits the criteria. + + + +``` python title="test_user.py" +from inline_snapshot import snapshot + + +def get_data(user): + return {"user": user, "age": 55} + + +def test_user(): + v_user = "Bob" + some_number = 50 + 5 + + assert get_data(v_user) == snapshot({"user": v_user, "age": 55}) +``` + +inline-snapshot uses `v_user` because it met the criteria in your customization hook, but not `some_number` because it does not start with `v_`. +You can also do this only for specific types of objects or for a whitelist of variable names. +It is up to you to set the rules that work best in your project. + +!!! note + It is not recommended to check only for the value because this might result in local variables which become part of the snapshot just because they are equal to the value and not because they should be there (see `age=55` in the example above). + This is also the reason why inline-snapshot does not provide default customizations that check your local variables. + The rules are project-specific and what might work well for one project can cause problems for others. + +### Creating special code + +Let's say that you have an array of secrets which are used in your code. + + +``` python title="my_secrets.py" +secrets = ["some_secret", "some_other_secret"] +``` + + +``` python title="get_data.py" +from my_secrets import secrets + + +def get_data(): + return {"data": "large data block", "used_secret": secrets[1]} +``` + +The problem is that `--inline-snapshot=create` puts your secret into your test. + + +``` python +from get_data import get_data +from inline_snapshot import snapshot + + +def test_my_class(): + assert get_data() == snapshot( + {"data": "large data block", "used_secret": "some_other_secret"} + ) +``` + +Maybe this is not what you want because the secret is a different one in CI or for every test run or the raw value leads to unreadable tests. +What you can do now, instead of replacing `"some_other_secret"` with `secrets[1]` by hand, is to tell inline-snapshot in your *conftest.py* how it should generate this code. + + +``` python title="conftest.py" +from my_secrets import secrets +from inline_snapshot.plugin import customize, Builder, ImportFrom + + +@customize +def secret_handler(value, builder: Builder): + for i, secret in enumerate(secrets): + if value == secret: + return builder.create_code( + f"secrets[{i}]", + imports=[ImportFrom("my_secrets", "secrets")], + ) +``` + +The [`create_code()`][inline_snapshot.plugin.Builder.create_code] method takes the desired code representation. The `imports` parameter adds the necessary import statements. + +inline-snapshot will now create the correct code and import statement when you run your tests with `--inline-snapshot=update`. + + +``` python hl_lines="4 5 9" +from get_data import get_data +from inline_snapshot import snapshot + +from my_secrets import secrets + + +def test_my_class(): + assert get_data() == snapshot( + {"data": "large data block", "used_secret": secrets[1]} + ) +``` + +!!! question "why update?" + `"some_other_secret"` was already a correct value for your assertion and `--inline-snapshot=fix` only changes code when the current value is not correct and needs to be fixed. `update` is the category for all other changes where inline-snapshot wants to generate different code which represents the same value as the code before. + + You only have to use `update` when you changed your customizations and want to use the new code representations in your existing tests. The new representation is also used by `create` or `fix` when you write new tests. + + The `update` category is not enabled by default for `--inline-snapshot=review/report`. + You can read [here](categories.md#update) more about it. + + + + + + + + + +## Reference +::: inline_snapshot.plugin + options: + heading_level: 3 + members: [hookimpl,customize,Builder,Custom,Import,ImportFrom] + show_root_heading: false + show_bases: false + show_source: false diff --git a/docs/testing.md b/docs/testing.md index 19977812..4d9e4c9f 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -5,8 +5,8 @@ The following example shows how you can use the `Example` class to test what inl ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -30,8 +30,8 @@ Inline-snapshot will then populate the empty snapshots. ``` python hl_lines="17 18 19 20 21 22 23 24 25 26" -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -68,8 +68,8 @@ This allows for more complex tests where you create one example and perform mult ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): @@ -116,8 +116,8 @@ You can also use the same example multiple times and call different methods on i ``` python -from inline_snapshot.testing import Example from inline_snapshot import snapshot +from inline_snapshot.testing import Example def test_something(): diff --git a/mkdocs.yml b/mkdocs.yml index 6759f2e9..fb37b78f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -47,6 +47,7 @@ nav: - Pytest: pytest.md - PyCharm: pycharm.md - Categories: categories.md + - Plugin: plugin.md - Code generation: code_generation.md - Limitations: limitations.md - Alternatives: alternatives.md diff --git a/pyproject.toml b/pyproject.toml index 3d517995..de02550c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "rich>=13.7.1", "tomli>=2.0.0; python_version < '3.11'", "pytest>=8.3.4", + "typing-extensions" ] description = "golden master/snapshot/approval testing library which puts the values right into your source code" keywords = [] @@ -59,7 +60,8 @@ dev = [ "coverage-enable-subprocess>=1.0", "attrs>=24.3.0", "pydantic>=1", - "black==25.1.0" + "black==25.1.0", + "isort" ] [project.entry-points.pytest11] @@ -93,7 +95,9 @@ exclude_lines = [ "# pragma: no cover", "if TYPE_CHECKING:", "if is_insider", + "\\.\\.\\." ] +ignore_errors=true [tool.coverage.run] @@ -126,14 +130,17 @@ dependencies = [ ] [tool.hatch.envs.docs.scripts] -build = "mkdocs build --strict" +build = "mkdocs build --strict {args}" export-deps = "pip freeze" -serve = "mkdocs serve --livereload" +serve = "mkdocs serve --livereload {args}" [tool.hatch.envs.default] installer="uv" +[tool.hatch.envs.cov] +python="3.12" + [tool.hatch.envs.cov.scripts] github=[ "- rm htmlcov/*", @@ -165,16 +172,7 @@ matrix.extra-deps.dependencies = [ ] [tool.hatch.envs.hatch-test] -extra-dependencies = [ - "inline-snapshot[black,dirty-equals]", - "dirty-equals>=0.9.0", - "hypothesis>=6.75.5", - "mypy>=1.2.0 ; implementation_name == 'cpython'", - "pyright>=1.1.359", - "pytest-freezer>=0.4.8", - "pytest-mock>=3.14.0", - "black==25.1.0" -] +dependency-groups=["dev"] env-vars.TOP = "{root}" [tool.hatch.envs.hatch-test.scripts] @@ -184,13 +182,9 @@ cov-combine = "coverage combine" cov-report=["coverage report","coverage html"] [tool.hatch.envs.types] +dependency-groups=["dev"] extra-dependencies = [ - "inline-snapshot[black,dirty-equals]", "mypy>=1.0.0", - "hypothesis>=6.75.5", - "pydantic", - "attrs", - "typing-extensions" ] [[tool.hatch.envs.types.matrix]] diff --git a/src/inline_snapshot/_adapter/__init__.py b/src/inline_snapshot/_adapter/__init__.py deleted file mode 100644 index 2f699011..00000000 --- a/src/inline_snapshot/_adapter/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .adapter import get_adapter_type - -__all__ = ("get_adapter_type",) diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py deleted file mode 100644 index 0ed39a40..00000000 --- a/src/inline_snapshot/_adapter/adapter.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import ast -import typing -from dataclasses import dataclass - -from inline_snapshot._source_file import SourceFile - - -def get_adapter_type(value): - from inline_snapshot._adapter.generic_call_adapter import get_adapter_for_type - - adapter = get_adapter_for_type(type(value)) - if adapter is not None: - return adapter - - if isinstance(value, list): - from .sequence_adapter import ListAdapter - - return ListAdapter - - if type(value) is tuple: - from .sequence_adapter import TupleAdapter - - return TupleAdapter - - if isinstance(value, dict): - from .dict_adapter import DictAdapter - - return DictAdapter - - from .value_adapter import ValueAdapter - - return ValueAdapter - - -class Item(typing.NamedTuple): - value: typing.Any - node: ast.expr - - -@dataclass -class FrameContext: - globals: dict - locals: dict - - -@dataclass -class AdapterContext: - file: SourceFile - frame: FrameContext | None - qualname: str - - def eval(self, node): - assert self.frame is not None - - return eval( - compile(ast.Expression(node), self.file.filename, "eval"), - self.frame.globals, - self.frame.locals, - ) - - -class Adapter: - context: AdapterContext - - def __init__(self, context: AdapterContext): - self.context = context - - def get_adapter(self, old_value, new_value) -> Adapter: - if type(old_value) is not type(new_value): - from .value_adapter import ValueAdapter - - return ValueAdapter(self.context) - - adapter_type = get_adapter_type(old_value) - if adapter_type is not None: - return adapter_type(self.context) - assert False - - def assign(self, old_value, old_node, new_value): - raise NotImplementedError(self) - - def value_assign(self, old_value, old_node, new_value): - from .value_adapter import ValueAdapter - - adapter = ValueAdapter(self.context) - result = yield from adapter.assign(old_value, old_node, new_value) - return result - - @classmethod - def map(cls, value, map_function): - raise NotImplementedError(cls) - - @classmethod - def repr(cls, value): - raise NotImplementedError(cls) - - -def adapter_map(value, map_function): - return get_adapter_type(value).map(value, map_function) diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py deleted file mode 100644 index 1f0b4a0c..00000000 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ /dev/null @@ -1,145 +0,0 @@ -from __future__ import annotations - -import ast -import warnings - -from .._change import Delete -from .._change import DictInsert -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item -from .adapter import adapter_map - - -class DictAdapter(Adapter): - - @classmethod - def repr(cls, value): - result = ( - "{" - + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) - + "}" - ) - - if type(value) is not dict: - result = f"{repr(type(value))}({result})" - - return result - - @classmethod - def map(cls, value, map_function): - return {k: adapter_map(v, map_function) for k, v in value.items()} - - @classmethod - def items(cls, value, node): - if node is None or not isinstance(node, ast.Dict): - return [Item(value=value, node=None) for value in value.values()] - - result = [] - - for value_key, node_key, node_value in zip( - value.keys(), node.keys, node.values - ): - try: - # this is just a sanity check, dicts should be ordered - node_key = ast.literal_eval(node_key) - except Exception: - pass - else: - assert node_key == value_key - - result.append(Item(value=value[value_key], node=node_value)) - - return result - - def assign(self, old_value, old_node, new_value): - if old_node is not None: - if not ( - isinstance(old_node, ast.Dict) and len(old_value) == len(old_node.keys) - ): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - for key, value in zip(old_node.keys, old_node.values): - if key is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - for value, node in zip(old_value.keys(), old_node.keys): - - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value - - result = {} - for key, node in zip( - old_value.keys(), - (old_node.values if old_node is not None else [None] * len(old_value)), - ): - if key not in new_value: - # delete entries - yield Delete("fix", self.context.file._source, node, old_value[key]) - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_value.items(): - if key not in old_value: - # add new values - to_insert.append((key, new_value_element)) - result[key] = new_value_element - else: - if isinstance(old_node, ast.Dict): - node = old_node.values[list(old_value.keys()).index(key)] - else: - node = None - # check values with same keys - result[key] = yield from self.get_adapter( - old_value[key], new_value[key] - ).assign(old_value[key], node, new_value[key]) - - if to_insert: - new_code = [ - ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), - ) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self.context.file._source, - old_node, - insert_pos, - new_code, - to_insert, - ) - to_insert = [] - - insert_pos += 1 - - if to_insert: - new_code = [ - ( - self.context.file._value_to_code(k), - self.context.file._value_to_code(v), - ) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self.context.file._source, - old_node, - len(old_value), - new_code, - to_insert, - ) - - return result diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py deleted file mode 100644 index afa67f8f..00000000 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ /dev/null @@ -1,449 +0,0 @@ -from __future__ import annotations - -import ast -import warnings -from abc import ABC -from collections import defaultdict -from dataclasses import MISSING -from dataclasses import fields -from dataclasses import is_dataclass -from typing import Any - -from .._change import CallArg -from .._change import Delete -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item -from .adapter import adapter_map - - -def get_adapter_for_type(value_type): - subclasses = GenericCallAdapter.__subclasses__() - options = [cls for cls in subclasses if cls.check_type(value_type)] - - if not options: - return - - assert len(options) == 1 - return options[0] - - -class Argument: - value: Any - is_default: bool = False - - def __init__(self, value, is_default=False): - self.value = value - self.is_default = is_default - - -class GenericCallAdapter(Adapter): - - @classmethod - def check_type(cls, value_type) -> bool: - raise NotImplementedError(cls) - - @classmethod - def arguments(cls, value) -> tuple[list[Argument], dict[str, Argument]]: - raise NotImplementedError(cls) - - @classmethod - def argument(cls, value, pos_or_name) -> Any: - raise NotImplementedError(cls) - - @classmethod - def repr(cls, value): - - args, kwargs = cls.arguments(value) - - arguments = [repr(value.value) for value in args] + [ - f"{key}={repr(value.value)}" - for key, value in kwargs.items() - if not value.is_default - ] - - return f"{repr(type(value))}({', '.join(arguments)})" - - @classmethod - def map(cls, value, map_function): - new_args, new_kwargs = cls.arguments(value) - return type(value)( - *[adapter_map(arg.value, map_function) for arg in new_args], - **{ - k: adapter_map(kwarg.value, map_function) - for k, kwarg in new_kwargs.items() - }, - ) - - @classmethod - def items(cls, value, node): - new_args, new_kwargs = cls.arguments(value) - - if node is not None: - assert isinstance(node, ast.Call) - assert all(kw.arg for kw in node.keywords) - kw_arg_node = {kw.arg: kw.value for kw in node.keywords if kw.arg}.get - - def pos_arg_node(pos): - return node.args[pos] - - else: - - def kw_arg_node(_): - return None - - def pos_arg_node(_): - return None - - return [ - Item(value=arg.value, node=pos_arg_node(i)) - for i, arg in enumerate(new_args) - ] + [ - Item(value=kw.value, node=kw_arg_node(name)) - for name, kw in new_kwargs.items() - ] - - def assign(self, old_value, old_node, new_value): - if old_node is None or not isinstance(old_node, ast.Call): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - call_type = self.context.eval(old_node.func) - - if not (isinstance(call_type, type) and self.check_type(call_type)): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - # positional arguments - for pos_arg in old_node.args: - if isinstance(pos_arg, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=pos_arg.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - # keyword arguments - for kw in old_node.keywords: - if kw.arg is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file._source.filename, - lineno=kw.value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - new_args, new_kwargs = self.arguments(new_value) - - # positional arguments - - result_args = [] - - for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): - old_value_element = self.argument(old_value, i) - result = yield from self.get_adapter( - old_value_element, new_value_element.value - ).assign(old_value_element, node, new_value_element.value) - result_args.append(result) - - if len(old_node.args) > len(new_args): - for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: - yield Delete( - "fix", - self.context.file._source, - node, - self.argument(old_value, arg_pos), - ) - - if len(old_node.args) < len(new_args): - for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=None, - new_code=self.context.file._value_to_code(value.value), - new_value=value.value, - ) - - # keyword arguments - result_kwargs = {} - for kw in old_node.keywords: - if kw.arg not in new_kwargs or new_kwargs[kw.arg].is_default: - # delete entries - yield Delete( - ( - "update" - if self.argument(old_value, kw.arg) - == self.argument(new_value, kw.arg) - else "fix" - ), - self.context.file._source, - kw.value, - self.argument(old_value, kw.arg), - ) - - old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_kwargs.items(): - if new_value_element.is_default: - continue - if key not in old_node_kwargs: - # add new values - to_insert.append((key, new_value_element.value)) - result_kwargs[key] = new_value_element.value - else: - node = old_node_kwargs[key] - - # check values with same keys - old_value_element = self.argument(old_value, key) - result_kwargs[key] = yield from self.get_adapter( - old_value_element, new_value_element.value - ).assign(old_value_element, node, new_value_element.value) - - if to_insert: - for key, value in to_insert: - - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=key, - new_code=self.context.file._value_to_code(value), - new_value=value, - ) - to_insert = [] - - insert_pos += 1 - - if to_insert: - - for key, value in to_insert: - - yield CallArg( - flag="fix", - file=self.context.file._source, - node=old_node, - arg_pos=insert_pos, - arg_name=key, - new_code=self.context.file._value_to_code(value), - new_value=value, - ) - - return type(old_value)(*result_args, **result_kwargs) - - -class DataclassAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return is_dataclass(value) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default != MISSING and field.default == field_value: - is_default = True - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - is_default = True - - kwargs[field.name] = Argument(value=field_value, is_default=is_default) - - return ([], kwargs) - - def argument(self, value, pos_or_name): - if isinstance(pos_or_name, str): - return getattr(value, pos_or_name) - else: - args = [field for field in fields(value) if field.init] - return args[pos_or_name] - - -try: - import attrs -except ImportError: # pragma: no cover - pass -else: - - class AttrAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return attrs.has(value) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for field in attrs.fields(type(value)): - if field.repr: - field_value = getattr(value, field.name) - is_default = False - - if field.default is not attrs.NOTHING: - - default_value = ( - field.default - if not isinstance(field.default, attrs.Factory) - else ( - field.default.factory() - if not field.default.takes_self - else field.default.factory(value) - ) - ) - - if default_value == field_value: - - is_default = True - - kwargs[field.name] = Argument( - value=field_value, is_default=is_default - ) - - return ([], kwargs) - - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - - -try: - import pydantic -except ImportError: # pragma: no cover - pass -else: - # import pydantic - if pydantic.version.VERSION.startswith("1."): - # pydantic v1 - from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] - - def get_fields(value): - return value.__fields__ - - else: - # pydantic v2 - from pydantic_core import PydanticUndefined - - def get_fields(value): - return type(value).model_fields - - from pydantic import BaseModel - - class PydanticContainer(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return issubclass(value, BaseModel) - - @classmethod - def arguments(cls, value): - - kwargs = {} - - for name, field in get_fields(value).items(): # type: ignore - if getattr(field, "repr", True): - field_value = getattr(value, name) - is_default = False - - if ( - field.default is not PydanticUndefined - and field.default == field_value - ): - is_default = True - - if ( - field.default_factory is not None - and field.default_factory() == field_value - ): - is_default = True - - kwargs[name] = Argument(value=field_value, is_default=is_default) - - return ([], kwargs) - - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - - -class IsNamedTuple(ABC): - _inline_snapshot_name = "namedtuple" - - _fields: tuple - _field_defaults: dict - - @classmethod - def __subclasshook__(cls, t): - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -class NamedTupleAdapter(GenericCallAdapter): - - @classmethod - def check_type(cls, value): - return issubclass(value, IsNamedTuple) - - @classmethod - def arguments(cls, value: IsNamedTuple): - - return ( - [], - { - field: Argument(value=getattr(value, field)) - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - }, - ) - - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - - -class DefaultDictAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value): - return issubclass(value, defaultdict) - - @classmethod - def arguments(cls, value: defaultdict): - - return ( - [Argument(value=value.default_factory), Argument(value=dict(value))], - {}, - ) - - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, int) - if pos_or_name == 0: - return value.default_factory - elif pos_or_name == 1: - return dict(value) - assert False diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py deleted file mode 100644 index a080e840..00000000 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -import ast -import warnings -from collections import defaultdict - -from .._align import add_x -from .._align import align -from .._change import Delete -from .._change import ListInsert -from .._compare_context import compare_context -from ..syntax_warnings import InlineSnapshotSyntaxWarning -from .adapter import Adapter -from .adapter import Item -from .adapter import adapter_map - - -class SequenceAdapter(Adapter): - node_type: type - value_type: type - braces: str - trailing_comma: bool - - @classmethod - def repr(cls, value): - if len(value) == 1 and cls.trailing_comma: - seq = repr(value[0]) + "," - else: - seq = ", ".join(map(repr, value)) - return cls.braces[0] + seq + cls.braces[1] - - @classmethod - def map(cls, value, map_function): - result = [adapter_map(v, map_function) for v in value] - return cls.value_type(result) - - @classmethod - def items(cls, value, node): - if node is None or not isinstance(node, cls.node_type): - return [Item(value=v, node=None) for v in value] - - assert len(value) == len(node.elts) - - return [Item(value=v, node=n) for v, n in zip(value, node.elts)] - - def assign(self, old_value, old_node, new_value): - if old_node is not None: - if not isinstance( - old_node, ast.List if isinstance(old_value, list) else ast.Tuple - ): - result = yield from self.value_assign(old_value, old_node, new_value) - return result - - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self.context.file.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return old_value - - with compare_context(): - diff = add_x(align(old_value, new_value)) - old = zip( - old_value, - old_node.elts if old_node is not None else [None] * len(old_value), - ) - new = iter(new_value) - old_position = 0 - to_insert = defaultdict(list) - result = [] - for c in diff: - if c in "mx": - old_value_element, old_node_element = next(old) - new_value_element = next(new) - v = yield from self.get_adapter( - old_value_element, new_value_element - ).assign(old_value_element, old_node_element, new_value_element) - result.append(v) - old_position += 1 - elif c == "i": - new_value_element = next(new) - new_code = self.context.file._value_to_code(new_value_element) - result.append(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) - elif c == "d": - old_value_element, old_node_element = next(old) - yield Delete( - "fix", - self.context.file._source, - old_node_element, - old_value_element, - ) - old_position += 1 - else: - assert False - - for position, code_values in to_insert.items(): - yield ListInsert( - "fix", self.context.file._source, old_node, position, *zip(*code_values) - ) - - return self.value_type(result) - - -class ListAdapter(SequenceAdapter): - node_type = ast.List - value_type = list - braces = "[]" - trailing_comma = False - - -class TupleAdapter(SequenceAdapter): - node_type = ast.Tuple - value_type = tuple - braces = "()" - trailing_comma = True diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py deleted file mode 100644 index a47e38eb..00000000 --- a/src/inline_snapshot/_adapter/value_adapter.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import ast -import warnings - -from .._change import Replace -from .._code_repr import value_code_repr -from .._sentinels import undefined -from .._unmanaged import Unmanaged -from .._unmanaged import update_allowed -from .._utils import value_to_token -from ..syntax_warnings import InlineSnapshotInfo -from .adapter import Adapter - - -class ValueAdapter(Adapter): - - @classmethod - def repr(cls, value): - return value_code_repr(value) - - @classmethod - def map(cls, value, map_function): - return map_function(value) - - def assign(self, old_value, old_node, new_value): - # generic fallback - - # because IsStr() != IsStr() - if isinstance(old_value, Unmanaged): - return old_value - - if old_node is None: - new_token = [] - else: - new_token = value_to_token(new_value) - - if isinstance(old_node, ast.JoinedStr) and isinstance(new_value, str): - if not old_value == new_value: - warnings.warn_explicit( - f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", - filename=self.context.file._source.filename, - lineno=old_node.lineno, - category=InlineSnapshotInfo, - ) - return old_value - - if not old_value == new_value: - if old_value is undefined: - flag = "create" - else: - flag = "fix" - elif ( - old_node is not None - and update_allowed(old_value) - and self.context.file._token_of_node(old_node) != new_token - ): - flag = "update" - else: - # equal and equal repr - return old_value - - new_code = self.context.file._token_to_code(new_token) - - yield Replace( - node=old_node, - file=self.context.file._source, - new_code=new_code, - flag=flag, - old_value=old_value, - new_value=new_value, - ) - - return new_value diff --git a/src/inline_snapshot/_adapter_context.py b/src/inline_snapshot/_adapter_context.py new file mode 100644 index 00000000..e95cf3c3 --- /dev/null +++ b/src/inline_snapshot/_adapter_context.py @@ -0,0 +1,27 @@ +import ast +from dataclasses import dataclass +from typing import Optional + +from inline_snapshot._source_file import SourceFile + + +@dataclass +class FrameContext: + globals: dict + locals: dict + + +@dataclass +class AdapterContext: + file: SourceFile + frame: Optional[FrameContext] + qualname: str + + def eval(self, node): + assert self.frame is not None + + return eval( + compile(ast.Expression(node), self.file.filename, "eval"), + self.frame.globals, + self.frame.locals, + ) diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index eb5c5487..176ec3dd 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -14,6 +14,7 @@ from executing.executing import EnhancedAST from inline_snapshot._external._external_location import Location +from inline_snapshot._external._find_external import ensure_import from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder @@ -116,30 +117,28 @@ def apply_external_changes(self): @dataclass() -class Delete(Change): - node: ast.AST - old_value: Any +class RequiredImport(Change): + module: str + name: str | None = None @dataclass() -class AddArgument(Change): - node: ast.Call - - position: int | None - name: str | None - - new_code: str - new_value: Any +class Delete(Change): + node: ast.AST | None + old_value: Any @dataclass() class ListInsert(Change): - node: ast.List + node: ast.List | ast.Tuple position: int new_code: list[str] new_values: list[Any] + def __post_init__(self): + self.new_code = [self.file.format_expression(v) for v in self.new_code] + @dataclass() class DictInsert(Change): @@ -149,6 +148,12 @@ class DictInsert(Change): new_code: list[tuple[str, str]] new_values: list[tuple[Any, Any]] + def __post_init__(self): + self.new_code = [ + (self.file.format_expression(k), self.file.format_expression(v)) + for k, v in self.new_code + ] + @dataclass() class Replace(Change): @@ -163,6 +168,9 @@ def apply(self, recorder: ChangeRecorder): range = self.file.asttokens().get_text_positions(self.node, False) change.replace(range, self.new_code, filename=self.filename) + def __post_init__(self): + self.new_code = self.file.format_expression(self.new_code) + @dataclass() class CallArg(Change): @@ -173,6 +181,9 @@ class CallArg(Change): new_code: str new_value: Any + def __post_init__(self): + self.new_code = self.file.format_expression(self.new_code) + TokenRange = Tuple[Token, Token] @@ -251,6 +262,10 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): ) sources: dict[EnhancedAST, SourceFile] = {} + # file -> module -> names + imports_by_file: dict[str, dict[str, set]] = defaultdict(lambda: defaultdict(set)) + module_imports_by_file: dict[str, set] = defaultdict(set) + for change in all_changes: if isinstance(change, Delete): node = cast(EnhancedAST, change.node).parent @@ -263,9 +278,19 @@ def apply_all(all_changes: list[ChangeBase], recorder: ChangeRecorder): node = cast(EnhancedAST, change.node) by_parent[node].append(change) sources[node] = change.file + elif isinstance(change, RequiredImport): + if change.name: + imports_by_file[change.file.filename][change.module].add(change.name) + else: + module_imports_by_file[change.file.filename].add(change.module) else: change.apply(recorder) + for filename in set(imports_by_file) | set(module_imports_by_file): + imports = imports_by_file.get(filename, defaultdict(set)) + module_imports = module_imports_by_file.get(filename, set()) + ensure_import(filename, imports, module_imports, recorder) + for parent, changes in by_parent.items(): source = sources[parent] @@ -320,7 +345,7 @@ def arg_token_range(node): else len(parent.args) + len(parent.keywords) ) to_insert[position].append( - f"{change.arg_name} = {change.new_code}" + f"{change.arg_name}={change.new_code}" ) else: assert change.arg_pos is not None diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 0e24cf4b..09df8105 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,11 +1,19 @@ -import ast -from enum import Enum -from enum import Flag +from __future__ import annotations + +import warnings +from contextlib import contextmanager from functools import singledispatch -from types import BuiltinFunctionType -from types import FunctionType +from typing import TYPE_CHECKING from unittest import mock +from typing_extensions import deprecated + +from inline_snapshot._generator_utils import only_value + +if TYPE_CHECKING: + from inline_snapshot._adapter_context import AdapterContext + + real_repr = repr @@ -35,19 +43,9 @@ def __eq__(self, other): if type(other) is not self._type: return False - other_repr = code_repr(other) - return other_repr == self._str_repr or other_repr == repr(self) - - -def used_hasrepr(tree): - return [ - n - for n in ast.walk(tree) - if isinstance(n, ast.Call) - and isinstance(n.func, ast.Name) - and n.func.id == "HasRepr" - and len(n.args) == 2 - ] + with mock_repr(None): + other_repr = value_code_repr(other) + return other_repr == self._str_repr or other_repr == real_repr(self) @singledispatch @@ -55,6 +53,7 @@ def code_repr_dispatch(value): return real_repr(value) +@deprecated("use @customize instead") def customize_repr(f): """Register a function which should be used to get the code representation of a object. @@ -71,24 +70,37 @@ def _(obj: MyCustomClass): * __repr__() of your class returns a valid code representation, * and __repr__() uses `repr()` to get the representation of the child objects """ + warnings.warn( + "@customize_repr is deprecated, @customize should be used instead", + DeprecationWarning, + ) code_repr_dispatch.register(f) def code_repr(obj): + from inline_snapshot._adapter_context import AdapterContext - with mock.patch("builtins.repr", mocked_code_repr): - return mocked_code_repr(obj) + context = AdapterContext(None, None, "") + with mock_repr(context): + return repr(obj) -def mocked_code_repr(obj): - from inline_snapshot._adapter.adapter import get_adapter_type +@contextmanager +def mock_repr(context: AdapterContext): + def new_repr(obj): + from inline_snapshot._customize._builder import Builder + + return only_value( + Builder(_snapshot_context=context)._get_handler(obj)._code_repr(context) + ) - adapter = get_adapter_type(obj) - assert adapter is not None - return adapter.repr(obj) + with mock.patch("builtins.repr", new_repr): + yield def value_code_repr(obj): + assert repr is not real_repr, "@mock_repr is missing" + if not type(obj) == type(obj): # pragma: no cover # this was caused by https://github.com/samuelcolvin/dirty-equals/issues/104 # dispatch will not work in cases like this @@ -98,70 +110,4 @@ def value_code_repr(obj): result = code_repr_dispatch(obj) - try: - ast.parse(result) - except SyntaxError: - return real_repr(HasRepr(type(obj), result)) - return result - - -# -8<- [start:Enum] -@customize_repr -def _(value: Enum): - return f"{type(value).__qualname__}.{value.name}" - - -# -8<- [end:Enum] - - -@customize_repr -def _(value: Flag): - name = type(value).__qualname__ - return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) - - -def sort_set_values(set_values): - is_sorted = False - try: - set_values = sorted(set_values) - is_sorted = True - except TypeError: - pass - - set_values = list(map(repr, set_values)) - if not is_sorted: - set_values = sorted(set_values) - - return set_values - - -@customize_repr -def _(value: set): - if len(value) == 0: - return "set()" - - return "{" + ", ".join(sort_set_values(value)) + "}" - - -@customize_repr -def _(value: frozenset): - if len(value) == 0: - return "frozenset()" - - return "frozenset({" + ", ".join(sort_set_values(value)) + "})" - - -@customize_repr -def _(value: type): - return value.__qualname__ - - -@customize_repr -def _(value: FunctionType): - return value.__qualname__ - - -@customize_repr -def _(value: BuiltinFunctionType): - return value.__name__ diff --git a/src/inline_snapshot/_compare_context.py b/src/inline_snapshot/_compare_context.py index 104a235c..c70bcdcd 100644 --- a/src/inline_snapshot/_compare_context.py +++ b/src/inline_snapshot/_compare_context.py @@ -13,5 +13,7 @@ def compare_context(): global _eq_check_only old_eq_only = _eq_check_only _eq_check_only = True - yield - _eq_check_only = old_eq_only + try: + yield + finally: + _eq_check_only = old_eq_only diff --git a/src/inline_snapshot/_customize/__init__.py b/src/inline_snapshot/_customize/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/inline_snapshot/_customize/_builder.py b/src/inline_snapshot/_customize/_builder.py new file mode 100644 index 00000000..05dd9eaa --- /dev/null +++ b/src/inline_snapshot/_customize/_builder.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import cached_property +from typing import Any +from typing import Callable + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._sentinels import undefined + +missing = undefined +from inline_snapshot._compare_context import compare_context +from inline_snapshot._exceptions import UsageError + +from ._custom import Custom +from ._custom_call import CustomCall +from ._custom_call import CustomDefault +from ._custom_code import CustomCode +from ._custom_code import Import +from ._custom_code import ImportFrom +from ._custom_dict import CustomDict +from ._custom_external import CustomExternal +from ._custom_sequence import CustomList +from ._custom_sequence import CustomTuple + + +@dataclass +class Builder: + _snapshot_context: AdapterContext + _build_new_value: bool = False + _recursive: bool = True + + def _get_handler_recursive(self, v) -> Custom: + if self._recursive: + return self._get_handler(v) + else: + return v + + def _get_handler(self, v) -> Custom: + + from inline_snapshot._global_state import state + + result = v + + while not isinstance(result, Custom): + with compare_context(): + r = state().pm.hook.customize( + value=result, + builder=self, + local_vars=self._get_local_vars, + global_vars=self._get_global_vars, + ) + if r is None: + result = CustomCode(result) + else: + result = r + + result.__dict__["original_value"] = v._eval() if isinstance(v, Custom) else v + + if not isinstance(v, Custom) and self._build_new_value: + is_same = False + v_eval = result._eval() + + if ( + hasattr(v, "__pydantic_generic_metadata__") + and v.__pydantic_generic_metadata__["origin"] == v_eval + ): + is_same = True + + if not is_same and v_eval == v: + is_same = True + + if not is_same: + raise UsageError( + f"""\ +Customized value does not match original value: + +original_value={v!r} + +customized_value={result._eval()!r} +customized_representation={result!r} +""" + ) + + return result + + def create_external( + self, value: Any, format: str | None = None, storage: str | None = None + ): + """ + Creates a new `external()` with the given format and storage. + """ + + return CustomExternal(value, format=format, storage=storage) + + def create_list(self, value: list) -> Custom: + """ + Creates an intermediate node for a list-expression which can be used as a result for your customization function. + + `create_list([1,2,3])` becomes `[1,2,3]` in the code. + List elements don't have to be Custom nodes and are converted by inline-snapshot if needed. + """ + custom = [self._get_handler_recursive(v) for v in value] + return CustomList(value=custom) + + def create_tuple(self, value: tuple) -> Custom: + """ + Creates an intermediate node for a tuple-expression which can be used as a result for your customization function. + + `create_tuple((1, 2, 3))` becomes `(1, 2, 3)` in the code. + Tuple elements don't have to be Custom nodes and are converted by inline-snapshot if needed. + """ + custom = [self._get_handler_recursive(v) for v in value] + return CustomTuple(value=custom) + + def with_default(self, value: Any, default: Any): + """ + Creates an intermediate node for a default value which can be used as an argument for create_call. + + Arguments are not included in the generated code when they match the actual default. + The value doesn't have to be a Custom node and is converted by inline-snapshot if needed. + """ + if isinstance(default, Custom): + raise UsageError("default value can not be an Custom value") + + if value == default: + return CustomDefault(value=self._get_handler_recursive(value)) + return value + + def create_call( + self, function: Custom | Callable, posonly_args=[], kwargs={} + ) -> Custom: + """ + Creates an intermediate node for a function call expression which can be used as a result for your customization function. + + `create_call(MyClass, [arg1, arg2], {'key': value})` becomes `MyClass(arg1, arg2, key=value)` in the code. + Function, arguments, and keyword arguments don't have to be Custom nodes and are converted by inline-snapshot if needed. + """ + function = self._get_handler_recursive(function) + posonly_args = [self._get_handler_recursive(arg) for arg in posonly_args] + kwargs = {k: self._get_handler_recursive(arg) for k, arg in kwargs.items()} + + return CustomCall( + function=function, + args=posonly_args, + kwargs=kwargs, + ) + + def create_dict(self, value: dict) -> Custom: + """ + Creates an intermediate node for a dict-expression which can be used as a result for your customization function. + + `create_dict({'key': 'value'})` becomes `{'key': 'value'}` in the code. + Dict keys and values don't have to be Custom nodes and are converted by inline-snapshot if needed. + """ + custom = { + self._get_handler_recursive(k): self._get_handler_recursive(v) + for k, v in value.items() + } + return CustomDict(value=custom) + + @cached_property + def _get_local_vars(self): + """Get local vars from snapshot context.""" + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + return { + var_name: var_value + for var_name, var_value in frame.locals.items() + if "@" not in var_name + } + return {} + + @cached_property + def _get_global_vars(self): + """Get global vars from snapshot context.""" + if ( + self._snapshot_context is not None + and (frame := self._snapshot_context.frame) is not None + ): + return { + var_name: var_value + for var_name, var_value in frame.globals.items() + if "@" not in var_name + } + return {} + + def _build_import_vars(self, imports): + """Build import vars from imports parameter.""" + import_vars = {} + if imports: + import importlib + + for imp in imports: + if isinstance(imp, Import): + # import module - makes top-level package available + importlib.import_module(imp.module) + top_level = imp.module.split(".")[0] + import_vars[top_level] = importlib.import_module(top_level) + elif isinstance(imp, ImportFrom): + # from module import name + module = importlib.import_module(imp.module) + import_vars[imp.name] = getattr(module, imp.name) + else: + assert False + return import_vars + + def create_code( + self, code: str, *, imports: list[Import | ImportFrom] = [] + ) -> Custom: + """ + Creates an intermediate node for a value with a custom representation which can be used as a result for your customization function. + + `create_code('{value-1!r}+1')` becomes `4+1` in the code. + Use this when you need to control the exact string representation of a value. + + Arguments: + code: Custom string representation to evaluate. This is required and will be evaluated using the snapshot context. + imports: Optional list of Import and ImportFrom objects to add required imports to the generated code. + Example: `imports=[Import("os"), ImportFrom("pathlib", "Path")]` + """ + import_vars = None + + # Try direct variable lookup for simple identifiers (fastest) + if code.isidentifier(): + # Direct lookup with proper precedence: local > import > global + if code in self._get_local_vars: + return CustomCode(self._get_local_vars[code], code, imports) + + # Build import vars only if needed + import_vars = self._build_import_vars(imports) + if code in import_vars: + return CustomCode(import_vars[code], code, imports) + + if code in self._get_global_vars: + return CustomCode(self._get_global_vars[code], code, imports) + + # Try ast.literal_eval for simple literals (fast and safe) + try: + import ast + + return CustomCode(ast.literal_eval(code), code, imports) + except (ValueError, SyntaxError): + # Fall back to eval with context for complex expressions + # Build evaluation context with proper precedence: global < import < local + if import_vars is None: + import_vars = self._build_import_vars(imports) + eval_context = { + **self._get_global_vars, + **import_vars, + **self._get_local_vars, + } + return CustomCode(eval(code, eval_context), code, imports) diff --git a/src/inline_snapshot/_customize/_custom.py b/src/inline_snapshot/_customize/_custom.py new file mode 100644 index 00000000..fe6a3575 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import ast +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING +from typing import Any +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +if TYPE_CHECKING: + pass + + +class Custom(ABC): + """ + Custom objects are returned by the `create_*` functions of the builder. + They should only be returned in your customize function or used as arguments for other `create_*` functions. + """ + + node_type: type[ast.AST] = ast.AST + original_value: Any + + def __hash__(self): + return hash(self._eval()) + + def __eq__(self, other): + assert isinstance(other, Custom) + return self._eval() == other._eval() + + @abstractmethod + def _map(self, f): + raise NotImplementedError() + + @abstractmethod + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + raise NotImplementedError() + + def _eval(self): + return self._map(lambda a: a) diff --git a/src/inline_snapshot/_customize/_custom_call.py b/src/inline_snapshot/_customize/_custom_call.py new file mode 100644 index 00000000..e2c884f9 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_call.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomDefault(Custom): + value: Custom = field(compare=False) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () # pragma: no cover + # this should never be called because default values are never converted into code + assert False + + def _map(self, f): + return self.value._map(f) + + +def unwrap_default(value): + if isinstance(value, CustomDefault): + return value.value + return value + + +@dataclass(frozen=True) +class CustomCall(Custom): + node_type = ast.Call + function: Custom = field(compare=False) + args: list[Custom] = field(compare=False) + kwargs: dict[str, Custom] = field(compare=False) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + args = [] + for a in self.args: + code = yield from a._code_repr(context) + args.append(code) + + for k, v in self.kwargs.items(): + if not isinstance(v, CustomDefault): + code = yield from v._code_repr(context) + args.append(f"{k}={code}") + + return f"{yield from self.function._code_repr(context)}({', '.join(args)})" + + def argument(self, pos_or_str): + if isinstance(pos_or_str, int): + return unwrap_default(self.args[pos_or_str]) + else: + return unwrap_default(self.kwargs[pos_or_str]) + + def _map(self, f): + return self.function._map(f)( + *[f(x._map(f)) for x in self.args], + **{k: f(v._map(f)) for k, v in self.kwargs.items()}, + ) diff --git a/src/inline_snapshot/_customize/_custom_code.py b/src/inline_snapshot/_customize/_custom_code.py new file mode 100644 index 00000000..9c0b27e8 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_code.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import ast +import importlib +from dataclasses import dataclass +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._change import RequiredImport +from inline_snapshot._code_repr import HasRepr +from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._utils import clone + +from ._custom import Custom + + +@dataclass(frozen=True) +class Import: + """Represents an import statement: `import module`""" + + module: str + + +@dataclass(frozen=True) +class ImportFrom: + """Represents a from-import statement: `from module import name`""" + + module: str + name: str + + +def _simplify_module_path(module: str, name: str) -> str: + """Simplify module path by finding the shortest import path for a given name.""" + value = getattr(importlib.import_module(module), name) + parts = module.split(".") + while len(parts) >= 2: + if getattr(importlib.import_module(".".join(parts[:-1])), name, None) == value: + parts.pop() + else: + break + return ".".join(parts) + + +class CustomCode(Custom): + _imports: list[Import | ImportFrom] + + def __init__(self, value, repr_str=None, imports: list[Import | ImportFrom] = []): + assert not isinstance(value, Custom) + value = clone(value) + self._imports = list(imports) + + if repr_str is None: + self.repr_str = value_code_repr(value) + + try: + ast.parse(self.repr_str) + except SyntaxError: + self.repr_str = HasRepr(type(value), self.repr_str).__repr__() + self._imports.append(ImportFrom("inline_snapshot", "HasRepr")) + else: + self.repr_str = repr_str + + self.value = value + + super().__init__() + + def _map(self, f): + return f(self.value) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + for imp in self._imports: + if isinstance(imp, Import): + yield RequiredImport(flag="fix", file=context.file, module=imp.module) + elif isinstance(imp, ImportFrom): + yield RequiredImport( + flag="fix", + file=context.file, + module=_simplify_module_path(imp.module, imp.name), + name=imp.name, + ) + else: + assert False + + return self.repr_str + + def __repr__(self): + return f"CustomCode({self.repr_str!r})" diff --git a/src/inline_snapshot/_customize/_custom_dict.py b/src/inline_snapshot/_customize/_custom_dict.py new file mode 100644 index 00000000..43274b10 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_dict.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomDict(Custom): + node_type = ast.Dict + value: dict[Custom, Custom] = field(compare=False) + + def _map(self, f): + return f({k._map(f): v._map(f) for k, v in self.value.items()}) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for k, v in self.value.items(): + key = yield from k._code_repr(context) + value = yield from v._code_repr(context) + values.append(f"{key}: {value}") + + return f"{{{ ', '.join(values)}}}" diff --git a/src/inline_snapshot/_customize/_custom_external.py b/src/inline_snapshot/_customize/_custom_external.py new file mode 100644 index 00000000..069fdcea --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_external.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._change import ExternalChange +from inline_snapshot._change import RequiredImport +from inline_snapshot._external._external_location import ExternalLocation +from inline_snapshot._external._format._protocol import get_format_handler + +from ._custom import Custom + + +@dataclass(frozen=True) +class CustomExternal(Custom): + value: Any + format: str | None = None + storage: str | None = None + + def _map(self, f): + return f(self.value) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + from inline_snapshot._global_state import state + + storage_name = self.storage or state().config.default_storage + + format = get_format_handler(self.value, self.format or "") + + location = ExternalLocation( + storage=storage_name, + stem="", + suffix=self.format or format.suffix, + filename=Path(context.file.filename), + qualname=context.qualname, + ) + + tmp_file = state().new_tmp_path(location.suffix) + + storage = state().all_storages[storage_name] + + format.encode(self.value, tmp_file) + location = storage.new_location(location, tmp_file) + + yield ExternalChange( + "create", + tmp_file, + ExternalLocation.from_name("", context=context), + location, + format, + ) + yield RequiredImport("create", context.file, "inline_snapshot", "external") + + return f"external({location.to_str()!r})" diff --git a/src/inline_snapshot/_customize/_custom_sequence.py b/src/inline_snapshot/_customize/_custom_sequence.py new file mode 100644 index 00000000..98252ae2 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_sequence.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +class CustomSequenceTypes: + trailing_comma: bool + braces: str + value_type: type + + +@dataclass(frozen=True) +class CustomSequence(Custom, CustomSequenceTypes): + value: list[Custom] = field(compare=False) + + def _map(self, f): + return f(self.value_type([x._map(f) for x in self.value])) + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + values = [] + for v in self.value: + value = yield from v._code_repr(context) + values.append(value) + + trailing_comma = self.trailing_comma and len(self.value) == 1 + return f"{self.braces[0]}{', '.join(values)}{', ' if trailing_comma else ''}{self.braces[1]}" + + +class CustomList(CustomSequence): + node_type = ast.List + value_type = list + braces = "[]" + trailing_comma = False + + +class CustomTuple(CustomSequence): + node_type = ast.Tuple + value_type = tuple + braces = "()" + trailing_comma = True diff --git a/src/inline_snapshot/_customize/_custom_undefined.py b/src/inline_snapshot/_customize/_custom_undefined.py new file mode 100644 index 00000000..90da8e69 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_undefined.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase +from inline_snapshot._sentinels import undefined + +from ._custom import Custom + + +class CustomUndefined(Custom): + def __init__(self): + self.value = undefined + + def _code_repr(self, context: AdapterContext) -> Generator[ChangeBase, None, str]: + yield from () + return "..." + + def _map(self, f): + return f(undefined) diff --git a/src/inline_snapshot/_customize/_custom_unmanaged.py b/src/inline_snapshot/_customize/_custom_unmanaged.py new file mode 100644 index 00000000..62b63e54 --- /dev/null +++ b/src/inline_snapshot/_customize/_custom_unmanaged.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import Generator + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._change import ChangeBase + +from ._custom import Custom + + +@dataclass() +class CustomUnmanaged(Custom): + value: Any + + def _code_repr( + self, context: AdapterContext + ) -> Generator[ChangeBase, None, str]: # pragma: no cover + yield from () + return "'unmanaged'" + + def _map(self, f): + return f(self.value) diff --git a/src/inline_snapshot/_external/_external.py b/src/inline_snapshot/_external/_external.py index 876809fc..553c5493 100644 --- a/src/inline_snapshot/_external/_external.py +++ b/src/inline_snapshot/_external/_external.py @@ -3,7 +3,7 @@ import ast from pathlib import Path -from inline_snapshot._adapter.adapter import AdapterContext +from inline_snapshot._adapter_context import AdapterContext from inline_snapshot._change import CallArg from inline_snapshot._change import Replace from inline_snapshot._exceptions import UsageError @@ -32,6 +32,9 @@ def is_inside_testdir(path: Path) -> bool: @declare_unmanaged class External(ExternalBase): + _original_location: ExternalLocation + _location: ExternalLocation + def __init__(self, name: str, expr, context: AdapterContext): """External objects are used to move some data outside the source code. You should not instantiate this class directly, but by using `external()` instead. diff --git a/src/inline_snapshot/_external/_external_base.py b/src/inline_snapshot/_external/_external_base.py index 76ae3de7..79295c02 100644 --- a/src/inline_snapshot/_external/_external_base.py +++ b/src/inline_snapshot/_external/_external_base.py @@ -4,10 +4,13 @@ from inline_snapshot._change import ExternalChange from inline_snapshot._exceptions import UsageError +from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._external._outsource import Outsourced +from inline_snapshot._external._storage._protocol import StorageLookupError from inline_snapshot._global_state import state from .._snapshot.generic_value import GenericValue +from ._external_location import ExternalLocation from ._external_location import Location @@ -46,15 +49,18 @@ def __eq__(self, other): __tracebackhide__ = True if isinstance(other, Outsourced): - self._location.suffix = other._location.suffix + + format = get_format_handler(other.data, other.suffix or "") + + self._location.suffix = other.suffix or format.suffix other = other.data - if isinstance(other, ExternalBase): + elif isinstance(other, ExternalBase): raise UsageError( f"you can not compare {external_type}(...) with {external_type}(...)" ) - if isinstance(other, GenericValue): + elif isinstance(other, GenericValue): raise UsageError( f"you can not compare {external_type}(...) with snapshot(...)" ) @@ -70,7 +76,17 @@ def __eq__(self, other): return True return False - value = self._load_value() + try: + value = self._load_value() + except StorageLookupError as error: + if not error.files and state().update_flags.fix: + self._original_location = ExternalLocation.from_name("") + self._assign(other) + state().incorrect_values += 1 + return True + else: + raise + result = value == other if not result and first_comparison: diff --git a/src/inline_snapshot/_external/_external_file.py b/src/inline_snapshot/_external/_external_file.py index e2d64d75..978e6d2f 100644 --- a/src/inline_snapshot/_external/_external_file.py +++ b/src/inline_snapshot/_external/_external_file.py @@ -39,7 +39,9 @@ def _load_value(self): try: return self._format.decode(self._filename) except FileNotFoundError: - raise StorageLookupError(f"can not read {self._filename}") + raise StorageLookupError( + f"can not read {self._filename}", files=[self._filename] + ) def external_file(path: Union[Path, str], *, format: Optional[str] = None): diff --git a/src/inline_snapshot/_external/_external_location.py b/src/inline_snapshot/_external/_external_location.py index 28f6a3ea..f5adddad 100644 --- a/src/inline_snapshot/_external/_external_location.py +++ b/src/inline_snapshot/_external/_external_location.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Generator -from inline_snapshot._adapter.adapter import AdapterContext +from inline_snapshot._adapter_context import AdapterContext class Location: diff --git a/src/inline_snapshot/_external/_find_external.py b/src/inline_snapshot/_external/_find_external.py index 5246e72c..8c2c697a 100644 --- a/src/inline_snapshot/_external/_find_external.py +++ b/src/inline_snapshot/_external/_find_external.py @@ -1,7 +1,11 @@ import ast +import os from dataclasses import replace from pathlib import Path +from typing import Dict from typing import List +from typing import Optional +from typing import Set from typing import Union from executing import Source @@ -24,6 +28,16 @@ def contains_import(tree, module, name): return False +def contains_module_import(tree, module): + for node in tree.body: + if isinstance(node, ast.Import): + if any( + alias.name == module and alias.asname is None for alias in node.names + ): + return True + return False + + def used_externals_in( filename: Path, source: Union[str, ast.Module], check_import=True ) -> List[ExternalLocation]: @@ -64,7 +78,44 @@ def used_externals_in( return usages -def ensure_import(filename, imports, recorder: ChangeRecorder): +def module_name_of(filename: Union[str, os.PathLike]) -> Optional[str]: + path = Path(filename).resolve() + + assert path.suffix == ".py" + + parts = [] + + if path.name != "__init__.py": + parts.append(path.stem) + + current = path.parent + + while current != current.root: + if not (current / "__init__.py").exists(): + break + + parts.append(current.name) + + next_parent = current.parent + if next_parent == current: + break # pragma: no cover + current = next_parent + else: + pass # pragma: no cover + + parts.reverse() + + assert parts + + return ".".join(parts) + + +def ensure_import( + filename, + imports: Dict[str, Set[str]], + module_imports: Set[str], + recorder: ChangeRecorder, +): source = Source.for_filename(filename) change = recorder.new_change() @@ -72,18 +123,35 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): tree = source.tree token = source.asttokens() - to_add = [] + my_module_name = module_name_of(filename) + code = "" for module, names in imports.items(): - for name in names: + if module == my_module_name: + continue + if module == "builtins": + continue + for name in sorted(names): if not contains_import(tree, module, name): - to_add.append((module, name)) + code += f"from {module} import {name}\n" + + for module in sorted(module_imports): + if not contains_module_import(tree, module): + code += f"import {module}\n" assert isinstance(tree, ast.Module) + # find source position last_import = None for node in tree.body: - if not isinstance(node, (ast.ImportFrom, ast.Import)): + if not ( + isinstance(node, (ast.ImportFrom, ast.Import)) + or ( + isinstance(node, ast.Expr) + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + ) + ): break last_import = node @@ -99,8 +167,6 @@ def ensure_import(filename, imports, recorder: ChangeRecorder): break position = end_of(last_token) - code = "" - for module, name in to_add: - code += f"\nfrom {module} import {name}\n" - - change.insert(position, code, filename=filename) + if code: + code = "\n" + code + change.insert(position, code, filename=filename) diff --git a/src/inline_snapshot/_external/_outsource.py b/src/inline_snapshot/_external/_outsource.py index 9a4c23ef..dd5e62ca 100644 --- a/src/inline_snapshot/_external/_outsource.py +++ b/src/inline_snapshot/_external/_outsource.py @@ -1,35 +1,18 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any -from inline_snapshot._external._external_location import ExternalLocation from inline_snapshot._external._format._protocol import get_format_handler from inline_snapshot._global_state import state - -from .._snapshot.generic_value import GenericValue +from inline_snapshot._snapshot.generic_value import GenericValue +@dataclass class Outsourced: - def __init__(self, data: Any, suffix: str | None): - self.data = data - - self._format = get_format_handler(data, suffix or "") - if suffix is None: - suffix = self._format.suffix - - self._location = ExternalLocation("hash", "", suffix, None, None) - - tmp_path = state().new_tmp_path(suffix) - - self._format.encode(data, tmp_path) - - storage = state().all_storages["hash"] - - self._location = storage.new_location( - self._location, tmp_path # type:ignore - ) - - storage.store(self._location, tmp_path) # type: ignore + data: Any + suffix: str | None + storage: str | None def __eq__(self, other): if isinstance(other, GenericValue): @@ -38,20 +21,22 @@ def __eq__(self, other): if isinstance(other, Outsourced): return self.data == other.data - return NotImplemented + from inline_snapshot._external._external_base import ExternalBase - def __repr__(self) -> str: - return f'external("{self._location.to_str()}")' + if isinstance(other, ExternalBase): + return NotImplemented - def _load_value(self) -> Any: - return self.data + return self.data == other -def outsource(data: Any, suffix: str | None = None) -> Any: +def outsource(data: Any, suffix: str | None = None, storage: str | None = None) -> Any: if suffix and suffix[0] != ".": raise ValueError("suffix has to start with a '.' like '.png'") if not state().active: return data - return Outsourced(data, suffix) + # check if the suffix/datatype is supported + get_format_handler(data, suffix or "") + + return Outsourced(data, suffix, storage) diff --git a/src/inline_snapshot/_external/_storage/__init__.py b/src/inline_snapshot/_external/_storage/__init__.py index 45220b1b..7766487e 100644 --- a/src/inline_snapshot/_external/_storage/__init__.py +++ b/src/inline_snapshot/_external/_storage/__init__.py @@ -9,5 +9,4 @@ def default_storages(storage_dir: Path): - return {"hash": HashStorage(storage_dir / "external"), "uuid": UuidStorage()} diff --git a/src/inline_snapshot/_external/_storage/_hash.py b/src/inline_snapshot/_external/_storage/_hash.py index 4876c5b7..f54b4f38 100644 --- a/src/inline_snapshot/_external/_storage/_hash.py +++ b/src/inline_snapshot/_external/_storage/_hash.py @@ -93,11 +93,13 @@ def _lookup_path(self, name) -> Path: if len(files) > 1: raise StorageLookupError( - f"hash collision files={sorted(f.name for f in files)}" + f"hash collision files={sorted(f.name for f in files)}", files=files ) if not files: - raise StorageLookupError(f"hash {name!r} is not found in the HashStorage") + raise StorageLookupError( + f"hash {name!r} is not found in the HashStorage", files=[] + ) return files[0] diff --git a/src/inline_snapshot/_external/_storage/_protocol.py b/src/inline_snapshot/_external/_storage/_protocol.py index 1e174ab3..45168e5c 100644 --- a/src/inline_snapshot/_external/_storage/_protocol.py +++ b/src/inline_snapshot/_external/_storage/_protocol.py @@ -8,7 +8,9 @@ class StorageLookupError(Exception): - pass + def __init__(self, msg, files): + super().__init__(msg) + self.files = files class StorageProtocol: diff --git a/src/inline_snapshot/_external/_storage/_uuid.py b/src/inline_snapshot/_external/_storage/_uuid.py index de40ad00..07cd3e26 100644 --- a/src/inline_snapshot/_external/_storage/_uuid.py +++ b/src/inline_snapshot/_external/_storage/_uuid.py @@ -48,7 +48,7 @@ def _lookup_path(self, location: ExternalLocation): if location.path in external_files(): return external_files()[location.path] else: - raise StorageLookupError(location) + raise StorageLookupError(location, files=[]) def store(self, location: ExternalLocation, file_path: Path): snapshot_path = self._get_path(location) diff --git a/src/inline_snapshot/_generator_utils.py b/src/inline_snapshot/_generator_utils.py new file mode 100644 index 00000000..5119fc82 --- /dev/null +++ b/src/inline_snapshot/_generator_utils.py @@ -0,0 +1,42 @@ +from collections import namedtuple + +IterResult = namedtuple("IterResult", "value list") + + +def split_gen(gen): + it = iter(gen) + l = [] + while True: + try: + l.append(next(it)) + except StopIteration as stop: + return IterResult(stop.value, l) + + +def only_value(gen): + return split_gen(gen).value + + +def gen_map(stream, f): + stream = iter(stream) + while True: + try: + yield f(next(stream)) + except StopIteration as stop: + return stop.value + + +def with_flag(stream, flag): + def map(change): + change.flag = flag + return change + + return gen_map(stream, map) + + +def make_gen_map(f): + def m(stream): + + return gen_map(stream, f) + + return m diff --git a/src/inline_snapshot/_get_snapshot_value.py b/src/inline_snapshot/_get_snapshot_value.py index 8239f8e4..a6bfffe6 100644 --- a/src/inline_snapshot/_get_snapshot_value.py +++ b/src/inline_snapshot/_get_snapshot_value.py @@ -1,6 +1,5 @@ from typing import TypeVar -from ._adapter.adapter import adapter_map from ._exceptions import UsageError from ._external._external import External from ._external._external_file import ExternalFile @@ -10,12 +9,14 @@ from ._is import Is from ._snapshot.generic_value import GenericValue from ._types import Snapshot -from ._unmanaged import Unmanaged def unwrap(value): if isinstance(value, GenericValue): - return adapter_map(value._visible_value(), lambda v: unwrap(v)[0]), True + return value._visible_value()._map(lambda v: unwrap(v)[0]), True + + if isinstance(value, Outsourced): + return (value.data, True) if isinstance(value, (External, Outsourced, ExternalFile)): try: @@ -23,9 +24,6 @@ def unwrap(value): except (UsageError, StorageLookupError): return (None, False) - if isinstance(value, Unmanaged): - return unwrap(value.value)[0], True - if isinstance(value, Is): return value.value, True diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py index 57fdc589..360b99e0 100644 --- a/src/inline_snapshot/_global_state.py +++ b/src/inline_snapshot/_global_state.py @@ -12,7 +12,10 @@ from typing import Literal from uuid import uuid4 +import pluggy + from inline_snapshot._config import Config +from inline_snapshot.plugin._spec import inline_snapshot_plugin_name if TYPE_CHECKING: from inline_snapshot._external._format._protocol import Format @@ -50,6 +53,10 @@ class State: default_factory=lambda: TemporaryDirectory(prefix="inline-snapshot-") ) + pm: pluggy.PluginManager = field( + default_factory=lambda: pluggy.PluginManager(inline_snapshot_plugin_name) + ) + def new_tmp_path(self, suffix: str) -> Path: assert self.tmp_dir is not None return Path(self.tmp_dir.name) / f"tmp-path-{uuid4()}{suffix}" @@ -75,6 +82,37 @@ def enter_snapshot_context(): _current.all_formats = dict(latest.all_formats) _current.config = deepcopy(latest.config) + from .plugin._spec import InlineSnapshotPluginSpec + + _current.pm.add_hookspecs(InlineSnapshotPluginSpec) + + from .plugin._default_plugin import InlineSnapshotPlugin + + _current.pm.register(InlineSnapshotPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotAttrsPlugin + except ImportError: # pragma: no cover + pass + else: + _current.pm.register(InlineSnapshotAttrsPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotPydanticPlugin + except ImportError: # pragma: no cover + pass + else: + _current.pm.register(InlineSnapshotPydanticPlugin()) + + try: + from .plugin._default_plugin import InlineSnapshotDirtyEqualsPlugin + except ImportError: # pragma: no cover + pass + else: + _current.pm.register(InlineSnapshotDirtyEqualsPlugin()) + + _current.pm.load_setuptools_entrypoints(inline_snapshot_plugin_name) + def leave_snapshot_context(): global _current diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 0f098153..441903fb 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -7,13 +7,15 @@ from executing import Source +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._adapter_context import FrameContext +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._generator_utils import with_flag from inline_snapshot._source_file import SourceFile from inline_snapshot._types import SnapshotRefBase -from ._adapter.adapter import AdapterContext -from ._adapter.adapter import FrameContext from ._change import CallArg -from ._change import Change +from ._change import ChangeBase from ._global_state import state from ._sentinels import undefined from ._snapshot.undecided_value import UndecidedValue @@ -125,18 +127,18 @@ def result(self): def create_raw(obj, context: AdapterContext): return obj - def _changes(self) -> Iterator[Change]: + def _changes(self) -> Iterator[ChangeBase]: if ( - self._value._old_value is undefined + isinstance(self._value._old_value, CustomUndefined) if self._expr is None else not self._expr.node.args ): - if self._value._new_value is undefined: + if isinstance(self._value._new_value, CustomUndefined): return - new_code = self._value._new_code() + new_code = yield from with_flag(self._value._new_code(), "create") yield CallArg( flag="create", diff --git a/src/inline_snapshot/_new_adapter.py b/src/inline_snapshot/_new_adapter.py new file mode 100644 index 00000000..e73cb871 --- /dev/null +++ b/src/inline_snapshot/_new_adapter.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +import ast +import warnings +from collections import defaultdict +from typing import Generator +from typing import Sequence + +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._align import add_x +from inline_snapshot._align import align +from inline_snapshot._change import CallArg +from inline_snapshot._change import ChangeBase +from inline_snapshot._change import Delete +from inline_snapshot._change import DictInsert +from inline_snapshot._change import ListInsert +from inline_snapshot._change import Replace +from inline_snapshot._compare_context import compare_context +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_call import CustomDefault +from inline_snapshot._customize._custom_code import CustomCode +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_sequence import CustomSequence +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged +from inline_snapshot._exceptions import UsageError +from inline_snapshot._generator_utils import make_gen_map +from inline_snapshot._generator_utils import only_value +from inline_snapshot._generator_utils import split_gen +from inline_snapshot.syntax_warnings import InlineSnapshotInfo +from inline_snapshot.syntax_warnings import InlineSnapshotSyntaxWarning + + +def warn_star_expression(node, context): + if isinstance(node, ast.Call): + for pos_arg in node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + # keyword arguments + for kw in node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + if isinstance(node, (ast.Tuple, ast.List)): + + for e in node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + if isinstance(node, ast.Dict): + + for key1, value in zip(node.keys, node.values): + if key1 is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=context.file._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return True + + return False + + +def reeval(old_value: Custom, value: Custom) -> Custom: + function_name = f"reeval_{type(old_value).__name__}" + result = globals()[function_name](old_value, value) + assert isinstance(result, Custom) + + return result + + +def reeval_CustomList(old_value: CustomList, value: CustomList): + assert len(old_value.value) == len(value.value) + return CustomList([reeval(a, b) for a, b in zip(old_value.value, value.value)]) + + +reeval_CustomTuple = reeval_CustomList + + +def reeval_CustomUnmanaged(old_value: CustomUnmanaged, value: CustomUnmanaged): + old_value.value = value.value + return old_value + + +def reeval_CustomUndefined(old_value, value): + return value + + +def reeval_CustomCode(old_value: CustomCode, value: CustomCode): + + if not old_value._eval() == value._eval(): + raise UsageError( + "snapshot value should not change. Use Is(...) for dynamic snapshot parts." + ) + + return value + + +def reeval_CustomCall(old_value: CustomCall, value: CustomCall): + return CustomCall( + reeval(old_value.function, value.function), + [reeval(a, b) for a, b in zip(old_value.args, value.args)], + {k: reeval(old_value.kwargs[k], value.kwargs[k]) for k in old_value.kwargs}, + ) + + +def reeval_CustomDict(old_value, value): + assert len(old_value.value) == len(value.value) + return CustomDict( + { + reeval(k1, k2): reeval(v1, v2) + for (k1, v1), (k2, v2) in zip(old_value.value.items(), value.value.items()) + } + ) + + +class NewAdapter: + + def __init__(self, context: AdapterContext): + self.context = context + + def compare( + self, old_value: Custom, old_node, new_value: Custom + ) -> Generator[ChangeBase, None, Custom]: + + if isinstance(old_value, CustomUnmanaged): + return old_value + + if isinstance(new_value, CustomUnmanaged): + raise UsageError("unmanaged values can not be compared with snapshots") + + if ( + type(old_value) is type(new_value) + and ( + isinstance(old_node, new_value.node_type) + if old_node is not None + else True + ) + and ( + isinstance(old_value, (CustomCall, CustomSequence, CustomDict)) + if old_node is None + else True + ) + ): + function_name = f"compare_{type(old_value).__name__}" + result = yield from getattr(self, function_name)( + old_value, old_node, new_value + ) + else: + result = yield from self.compare_CustomCode(old_value, old_node, new_value) + return result + + def compare_CustomCode( + self, old_value: Custom, old_node: ast.expr, new_value: Custom + ) -> Generator[ChangeBase, None, Custom]: + + assert isinstance(old_value, Custom) + assert isinstance(new_value, Custom) + assert isinstance(old_node, (ast.expr, type(None))), old_node + + new_code, new_changes = split_gen(new_value._code_repr(self.context)) + + if ( + isinstance(old_node, ast.JoinedStr) + and isinstance(new_value, CustomCode) + and isinstance(new_value.value, str) + ): + if not old_value._eval() == new_value._eval(): + + value = only_value(new_value._code_repr(self.context)) + warnings.warn_explicit( + f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {value}", + filename=self.context.file._source.filename, + lineno=old_node.lineno, + category=InlineSnapshotInfo, + ) + return old_value + + if not old_value._eval() == new_value.original_value: + if isinstance(old_value, CustomUndefined): + flag = "create" + else: + flag = "fix" + elif not isinstance( + old_value, CustomUnmanaged + ) and self.context.file.code_changed(old_node, new_code): + flag = "update" + else: + # equal and equal repr + return old_value + + for change in new_changes: + change.flag = flag + yield change + + yield Replace( + node=old_node, + file=self.context.file, + new_code=new_code, + flag=flag, + old_value=old_value._eval(), + new_value=new_value, + ) + + return new_value + + def compare_CustomSequence( + self, old_value: CustomSequence, old_node: ast.AST, new_value: CustomSequence + ) -> Generator[ChangeBase, None, CustomSequence]: + + if old_node is not None: + assert isinstance( + old_node, ast.List if isinstance(old_value._eval(), list) else ast.Tuple + ) + assert isinstance(old_node, (ast.List, ast.Tuple)) + + else: + pass # pragma: no cover + + with compare_context(): + diff = add_x(align(old_value.value, new_value.value)) + old = zip( + old_value.value, + old_node.elts if old_node is not None else [None] * len(old_value.value), + ) + new = iter(new_value.value) + old_position = 0 + to_insert = defaultdict(list) + result = [] + for c in diff: + if c in "mx": + old_value_element, old_node_element = next(old) + new_value_element = next(new) + v = yield from self.compare( + old_value_element, old_node_element, new_value_element + ) + result.append(v) + old_position += 1 + elif c == "i": + new_value_element = next(new) + new_code = yield from new_value_element._code_repr(self.context) + result.append(new_value_element) + to_insert[old_position].append((new_code, new_value_element)) + elif c == "d": + old_value_element, old_node_element = next(old) + yield Delete( + "fix", + self.context.file, + old_node_element, + old_value_element, + ) + old_position += 1 + else: + assert False + + for position, code_values in to_insert.items(): + yield ListInsert( + "fix", + self.context.file, + old_node, + position, + *zip(*code_values), # type:ignore + ) + + return type(new_value)(result) + + compare_CustomTuple = compare_CustomSequence + compare_CustomList = compare_CustomSequence + + def compare_CustomDict( + self, old_value: CustomDict, old_node: ast.Dict, new_value: CustomDict + ) -> Generator[ChangeBase, None, Custom]: + assert isinstance(old_value, CustomDict) + assert isinstance(new_value, CustomDict) + + if old_node is not None: + + for value2, node in zip(old_value.value.keys(), old_node.keys): + assert node is not None + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except Exception: + continue + assert node_value == value2._eval() + else: + pass # pragma: no cover + + result = {} + for key2, node2 in zip( + old_value.value.keys(), + ( + old_node.values + if old_node is not None + else [None] * len(old_value.value) + ), + ): + if key2 not in new_value.value: + # delete entries + yield Delete("fix", self.context.file, node2, old_value.value[key2]) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_value.value.items(): + if key not in old_value.value: + # add new values + to_insert.append((key, new_value_element)) + result[key] = new_value_element + else: + if isinstance(old_node, ast.Dict): + node = old_node.values[list(old_value.value.keys()).index(key)] + else: + node = None + # check values with same keys + result[key] = yield from self.compare( + old_value.value[key], node, new_value.value[key] + ) + + if to_insert: + new_code = [] + for k, v in to_insert: + new_code_key = yield from k._code_repr(self.context) + new_code_value = yield from v._code_repr(self.context) + new_code.append((new_code_key, new_code_value)) + + yield DictInsert( + "fix", + self.context.file, + old_node, + insert_pos, + new_code, + to_insert, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + new_code = [] + for k, v in to_insert: + new_code_key = yield from k._code_repr(self.context) + new_code_value = yield from v._code_repr(self.context) + new_code.append( + ( + new_code_key, + new_code_value, + ) + ) + + yield DictInsert( + "fix", + self.context.file, + old_node, + len(old_value.value), + new_code, + to_insert, + ) + + return CustomDict(value=result) + + def compare_CustomCall( + self, old_value: CustomCall, old_node: ast.Call, new_value: CustomCall + ) -> Generator[ChangeBase, None, Custom]: + + call = new_value + new_args = call.args + new_kwargs = call.kwargs + + # positional arguments + + result_args = [] + + flag = "update" if old_value._eval() == new_value.original_value else "fix" + + @make_gen_map + def intercept(change): + if flag == "update" and change.flag == "fix": + change.flag = "update" + return change + + old_node_args: Sequence[ast.expr | None] + if old_node: + old_node_args = old_node.args + else: + old_node_args = [None] * len(new_args) + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node_args)): + old_value_element = old_value.argument(i) + result = yield from intercept( + self.compare(old_value_element, node, new_value_element) + ) + result_args.append(result) + + old_args_len = len(old_node.args if old_node else old_value.args) + + if old_node is not None: + if old_args_len > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + yield Delete( + flag, + self.context.file, + node, + old_value.argument(arg_pos), + ) + + if old_args_len < len(new_args): + for insert_pos, insert_value in list(enumerate(new_args))[old_args_len:]: + new_code = yield from insert_value._code_repr(self.context) + yield CallArg( + flag=flag, + file=self.context.file, + node=old_node, + arg_pos=insert_pos, + arg_name=None, + new_code=new_code, + new_value=insert_value, + ) + + # keyword arguments + result_kwargs = {} + if old_node is None: + old_keywords = {key: None for key in old_value.kwargs.keys()} + else: + old_keywords = { + kw.arg: kw.value for kw in old_node.keywords if kw.arg is not None + } + + for kw_arg, kw_value in old_keywords.items(): + missing = kw_arg not in new_kwargs + if missing or isinstance(new_kwargs[kw_arg], CustomDefault): + # delete entries + yield Delete( + ( + "update" + if not missing + and old_value.argument(kw_arg) == new_value.argument(kw_arg) + else flag + ), + self.context.file, + kw_value, + old_value.argument(kw_arg), + ) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_kwargs.items(): + if isinstance(new_value_element, CustomDefault): + continue + if key not in old_keywords: + # add new values + to_insert.append((key, new_value_element)) + result_kwargs[key] = new_value_element + else: + node = old_keywords[key] + + # check values with same keys + old_value_element = old_value.argument(key) + result_kwargs[key] = yield from intercept( + self.compare(old_value_element, node, new_value_element) + ) + + if to_insert: + for insert_key, value in to_insert: + new_code = yield from value._code_repr(self.context) + yield CallArg( + flag=flag, + file=self.context.file, + node=old_node, + arg_pos=insert_pos, + arg_name=insert_key, + new_code=new_code, + new_value=value, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + + for key, value in to_insert: + new_code = yield from value._code_repr(self.context) + + yield CallArg( + flag=flag, + file=self.context.file, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=new_code, + new_value=value, + ) + + return CustomCall( + ( + yield from intercept( + self.compare( + old_value.function, + old_node.func if old_node else None, + new_value.function, + ) + ) + ), + result_args, + result_kwargs, + ) diff --git a/src/inline_snapshot/_snapshot/__init__.py b/src/inline_snapshot/_snapshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py index 951ca021..360c452f 100644 --- a/src/inline_snapshot/_snapshot/collection_value.py +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -1,49 +1,56 @@ import ast +from typing import Generator from typing import Iterator +from typing import Union -from .._change import Change +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_undefined import CustomUndefined + +from .._change import ChangeBase from .._change import Delete from .._change import ListInsert from .._change import Replace from .._global_state import state -from .._sentinels import undefined -from .._utils import value_to_token from .generic_value import GenericValue -from .generic_value import clone from .generic_value import ignore_old_value class CollectionValue(GenericValue): _current_op = "x in snapshot" + _ast_node: Union[ast.List, ast.Tuple] + _new_value: CustomList def __contains__(self, item): - if self._old_value is undefined: + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if self._new_value is undefined: - self._new_value = [clone(item)] + if isinstance(self._new_value, CustomUndefined): + self._new_value = CustomList([self.to_custom(item)]) else: - if item not in self._new_value: - self._new_value.append(clone(item)) + if item not in self._new_value._eval(): + self._new_value.value.append(self.to_custom(item)) - if ignore_old_value() or self._old_value is undefined: + if ignore_old_value() or isinstance(self._old_value, CustomUndefined): return True else: - return self._return(item in self._old_value) + return self._return(item in self._old_value._eval()) - def _new_code(self): - return self._file._value_to_code(self._new_value) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value._code_repr(self._context) + return code - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: + assert isinstance(self._old_value, CustomList), self._old_value + assert isinstance(self._new_value, CustomList), self._new_value if self._ast_node is None: - elements = [None] * len(self._old_value) + elements = [None] * len(self._old_value.value) else: assert isinstance(self._ast_node, ast.List) elements = self._ast_node.elts - for old_value, old_node in zip(self._old_value, elements): - if old_value not in self._new_value: + for old_value, old_node in zip(self._old_value.value, elements): + if old_value not in self._new_value.value: yield Delete( flag="trim", file=self._file, @@ -53,13 +60,11 @@ def _get_changes(self) -> Iterator[Change]: continue # check for update - new_token = value_to_token(old_value) + new_code = yield from self.to_custom(old_value._eval())._code_repr( + self._context + ) - if ( - old_node is not None - and self._file._token_of_node(old_node) != new_token - ): - new_code = self._file._token_to_code(new_token) + if self._file.code_changed(old_node, new_code): yield Replace( node=old_node, @@ -70,13 +75,20 @@ def _get_changes(self) -> Iterator[Change]: new_value=old_value, ) - new_values = [v for v in self._new_value if v not in self._old_value] - if new_values: + new_codes = [] + new_values = [] + for v in self._new_value.value: + if v not in self._old_value.value: + new_code = yield from v._code_repr(self._context) + new_codes.append(new_code) + new_values.append(v._eval()) + + if new_codes: yield ListInsert( flag="fix", file=self._file, node=self._ast_node, - position=len(self._old_value), - new_code=[self._file._value_to_code(v) for v in new_values], + position=len(self._old_value.value), + new_code=new_codes, new_values=new_values, ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py index afed0073..3071f144 100644 --- a/src/inline_snapshot/_snapshot/dict_value.py +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -1,97 +1,110 @@ import ast +from typing import Generator from typing import Iterator -from .._adapter.adapter import AdapterContext -from .._change import Change +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_undefined import CustomUndefined + +from .._adapter_context import AdapterContext +from .._change import ChangeBase from .._change import Delete from .._change import DictInsert from .._global_state import state -from .._inline_snapshot import UndecidedValue -from .._sentinels import undefined from .generic_value import GenericValue +from .undecided_value import UndecidedValue class DictValue(GenericValue): _current_op = "snapshot[key]" + _new_value: CustomDict + _old_value: CustomDict + _ast_node: ast.Dict + def __getitem__(self, index): + if isinstance(self._new_value, CustomUndefined): + self._new_value = CustomDict({}) - if self._new_value is undefined: - self._new_value = {} + index = self.to_custom(index) - if index not in self._new_value: - old_value = self._old_value - if old_value is undefined: + if index not in self._new_value.value: + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 old_value = {} + else: + old_value = self._old_value.value child_node = None if self._ast_node is not None: assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) + old_keys = [k for k in old_value.keys()] + if index in old_keys: + pos = old_keys.index(index) child_node = self._ast_node.values[pos] - self._new_value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._context + self._new_value.value[index] = UndecidedValue( + old_value.get(index, CustomUndefined()), child_node, self._context ) - return self._new_value[index] + return self._new_value.value[index] def _re_eval(self, value, context: AdapterContext): super()._re_eval(value, context) - if self._new_value is not undefined and self._old_value is not undefined: - for key, s in self._new_value.items(): - if key in self._old_value: - s._re_eval(self._old_value[key], context) - - def _new_code(self): - return ( - "{" - + ", ".join( - [ - f"{self._file._value_to_code(k)}: {v._new_code()}" - for k, v in self._new_value.items() - if not isinstance(v, UndecidedValue) - ] - ) - + "}" - ) + if not isinstance(self._new_value, CustomUndefined) and not isinstance( + self._old_value, CustomUndefined + ): + for key, s in self._new_value.value.items(): + if key in self._old_value.value: + s._re_eval(self._old_value.value[key], context) # type:ignore - def _get_changes(self) -> Iterator[Change]: + def _new_code(self) -> Generator[ChangeBase, None, str]: + values = [] + for k, v in self._new_value.value.items(): + if not isinstance(v, UndecidedValue): + new_code = yield from v._new_code() # type:ignore + new_key = yield from k._code_repr(self._context) + values.append(f"{new_key}: {new_code}") - assert self._old_value is not undefined + return "{" + ", ".join(values) + "}" + + def _get_changes(self) -> Iterator[ChangeBase]: + + assert not isinstance(self._old_value, CustomUndefined) if self._ast_node is None: - values = [None] * len(self._old_value) + values = [None] * len(self._old_value.value) else: assert isinstance(self._ast_node, ast.Dict) values = self._ast_node.values - for key, node in zip(self._old_value.keys(), values): - if key in self._new_value: + for key, node in zip(self._old_value.value.keys(), values): + if key in self._new_value.value: # check values with same keys - yield from self._new_value[key]._get_changes() + yield from self._new_value.value[key]._get_changes() # type:ignore else: # delete entries - yield Delete("trim", self._file, node, self._old_value[key]) + yield Delete("trim", self._file, node, self._old_value.value[key]) to_insert = [] - for key, new_value_element in self._new_value.items(): - if key not in self._old_value and not isinstance( + to_insert_values = [] + for key, new_value_element in self._new_value.value.items(): + if key not in self._old_value.value and not isinstance( new_value_element, UndecidedValue ): # add new values - to_insert.append((key, new_value_element._new_code())) + new_value = yield from new_value_element._new_code() # type:ignore + new_key = yield from key._code_repr(self._context) + + to_insert.append((new_key, new_value)) + to_insert_values.append((key, new_value_element)) if to_insert: - new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] yield DictInsert( "create", self._file, self._ast_node, - len(self._old_value), - new_code, + len(self._old_value.value), to_insert, + to_insert_values, ) diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py index 9e141964..1cb39e54 100644 --- a/src/inline_snapshot/_snapshot/eq_value.py +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -1,14 +1,16 @@ +from typing import Generator from typing import Iterator from typing import List -from inline_snapshot._adapter.adapter import Adapter +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._generator_utils import split_gen +from inline_snapshot._new_adapter import NewAdapter from .._change import Change +from .._change import ChangeBase from .._compare_context import compare_only from .._global_state import state -from .._sentinels import undefined from .generic_value import GenericValue -from .generic_value import clone class EqValue(GenericValue): @@ -16,24 +18,30 @@ class EqValue(GenericValue): _changes: List[Change] def __eq__(self, other): - if self._old_value is undefined: + custom_other = self.to_custom(other, _build_new_value=True) + + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if not compare_only() and self._new_value is undefined: + if not compare_only() and isinstance(self._new_value, CustomUndefined): self._changes = [] - adapter = Adapter(self._context).get_adapter(self._old_value, other) - it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) - while True: - try: - self._changes.append(next(it)) - except StopIteration as ex: - self._new_value = ex.value - break - return self._return(self._old_value == other, self._new_value == other) + adapter = NewAdapter(self._context) + + result = split_gen( + adapter.compare(self._old_value, self._ast_node, custom_other) + ) + self._changes = result.list + self._new_value = result.value + + return self._return( + self._old_value._eval() == other, + self._new_value._eval() == other, + ) - def _new_code(self): - return self._file._value_to_code(self._new_value) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value._code_repr(self._context) + return code def _get_changes(self) -> Iterator[Change]: - return iter(self._changes) + return iter(getattr(self, "_changes", [])) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py index 29db3fc8..39a157a5 100644 --- a/src/inline_snapshot/_snapshot/generic_value.py +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -1,37 +1,18 @@ import ast -import copy -from typing import Any +from typing import Generator from typing import Iterator -from .._adapter.adapter import AdapterContext -from .._adapter.adapter import get_adapter_type -from .._change import Change -from .._code_repr import code_repr -from .._exceptions import UsageError +from inline_snapshot._adapter_context import AdapterContext +from inline_snapshot._code_repr import mock_repr +from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._new_adapter import reeval + +from .._change import ChangeBase from .._global_state import state -from .._sentinels import undefined from .._types import SnapshotBase -from .._unmanaged import Unmanaged from .._unmanaged import declare_unmanaged -from .._unmanaged import update_allowed - - -def clone(obj): - new = copy.deepcopy(obj) - if not obj == new: - raise UsageError( - f"""\ -inline-snapshot uses `copy.deepcopy` to copy objects, -but the copied object is not equal to the original one: - -value = {code_repr(obj)} -copied_value = copy.deepcopy(value) -assert value == copied_value - -Please fix the way your object is copied or your __eq__ implementation. -""" - ) - return new def ignore_old_value(): @@ -40,19 +21,27 @@ def ignore_old_value(): @declare_unmanaged class GenericValue(SnapshotBase): - _new_value: Any - _old_value: Any + _new_value: Custom + _old_value: Custom _current_op = "undefined" - _ast_node: ast.Expr + _ast_node: ast.expr _context: AdapterContext + def get_builder(self, **args): + return Builder(_snapshot_context=self._context, **args) + def _return(self, result, new_result=True): if not result: state().incorrect_values += 1 flags = state().update_flags - if flags.fix or flags.create or flags.update or self._old_value is undefined: + if ( + flags.fix + or flags.create + or flags.update + or isinstance(self._old_value, CustomUndefined) + ): return new_result return result @@ -60,45 +49,34 @@ def _return(self, result, new_result=True): def _file(self): return self._context.file - def get_adapter(self, value): - return get_adapter_type(value)(self._context) + def to_custom(self, value, **args): + with mock_repr(self._context): + return self.get_builder(**args)._get_handler(value) - def _re_eval(self, value, context: AdapterContext): - self._context = context - - def re_eval(old_value, node, value): - if isinstance(old_value, Unmanaged): - old_value.value = value - return + def value_to_custom(self, value): + if isinstance(value, Custom): + return value - assert type(old_value) is type(value) + if self._ast_node is None: + from inline_snapshot._snapshot.undecided_value import ValueToCustom - adapter = self.get_adapter(old_value) - if adapter is not None and hasattr(adapter, "items"): - old_items = adapter.items(old_value, node) - new_items = adapter.items(value, node) - assert len(old_items) == len(new_items) + return ValueToCustom(self._context).convert(value) + else: + from inline_snapshot._snapshot.undecided_value import AstToCustom - for old_item, new_item in zip(old_items, new_items): - re_eval(old_item.value, old_item.node, new_item.value) + return AstToCustom(self._context).convert(value, self._ast_node) - else: - if update_allowed(old_value): - if not old_value == value: - raise UsageError( - "snapshot value should not change. Use Is(...) for dynamic snapshot parts." - ) - else: - assert False, "old_value should be converted to Unmanaged" + def _re_eval(self, value, context: AdapterContext): + self._context = context - re_eval(self._old_value, self._ast_node, value) + self._old_value = reeval(self._old_value, self.value_to_custom(value)) def _ignore_old(self): return ( state().update_flags.fix or state().update_flags.update or state().update_flags.create - or self._old_value is undefined + or isinstance(self._old_value, CustomUndefined) ) def _visible_value(self): @@ -107,14 +85,14 @@ def _visible_value(self): else: return self._old_value - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: raise NotImplementedError() - def _new_code(self): + def _new_code(self) -> Generator[ChangeBase, None, str]: raise NotImplementedError() def __repr__(self): - return repr(self._visible_value()) + return repr(self._visible_value()._eval()) def _type_error(self, op): __tracebackhide__ = True @@ -122,22 +100,22 @@ def _type_error(self, op): f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`" ) - def __eq__(self, _other): + def __eq__(self, other): __tracebackhide__ = True self._type_error("==") - def __le__(self, _other): + def __le__(self, other): __tracebackhide__ = True self._type_error("<=") - def __ge__(self, _other): + def __ge__(self, other): __tracebackhide__ = True self._type_error(">=") - def __contains__(self, _other): + def __contains__(self, item): __tracebackhide__ = True self._type_error("in") - def __getitem__(self, _item): + def __getitem__(self, item): __tracebackhide__ = True self._type_error("snapshot[key]") diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py index 9ef0a65c..4af440df 100644 --- a/src/inline_snapshot/_snapshot/min_max_value.py +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -1,12 +1,12 @@ +from typing import Generator from typing import Iterator -from .._change import Change +from inline_snapshot._customize._custom_undefined import CustomUndefined + +from .._change import ChangeBase from .._change import Replace from .._global_state import state -from .._sentinels import undefined -from .._utils import value_to_token from .generic_value import GenericValue -from .generic_value import clone from .generic_value import ignore_old_value @@ -18,46 +18,43 @@ def cmp(a, b): raise NotImplementedError def _generic_cmp(self, other): - if self._old_value is undefined: + if isinstance(self._old_value, CustomUndefined): state().missing_values += 1 - if self._new_value is undefined: - self._new_value = clone(other) - if self._old_value is undefined or ignore_old_value(): + if isinstance(self._new_value, CustomUndefined): + self._new_value = self.to_custom(other) + if isinstance(self._old_value, CustomUndefined) or ignore_old_value(): return True - return self._return(self.cmp(self._old_value, other)) + return self._return(self.cmp(self._old_value._eval(), other)) else: - if not self.cmp(self._new_value, other): - self._new_value = clone(other) + if not self.cmp(self._new_value._eval(), other): + self._new_value = self.to_custom(other) + + return self._return(self.cmp(self._visible_value()._eval(), other)) - return self._return(self.cmp(self._visible_value(), other)) + def _new_code(self) -> Generator[ChangeBase, None, str]: + code = yield from self._new_value._code_repr(self._context) + return code - def _new_code(self): - return self._file._value_to_code(self._new_value) + def _get_changes(self) -> Iterator[ChangeBase]: + new_code = yield from self._new_code() - def _get_changes(self) -> Iterator[Change]: - new_token = value_to_token(self._new_value) - if not self.cmp(self._old_value, self._new_value): + if not self.cmp(self._old_value._eval(), self._new_value._eval()): flag = "fix" - elif not self.cmp(self._new_value, self._old_value): + elif not self.cmp(self._new_value._eval(), self._old_value._eval()): flag = "trim" - elif ( - self._ast_node is not None - and self._file._token_of_node(self._ast_node) != new_token - ): + elif self._file.code_changed(self._ast_node, new_code): flag = "update" else: return - new_code = self._file._token_to_code(new_token) - yield Replace( node=self._ast_node, file=self._file, new_code=new_code, flag=flag, - old_value=self._old_value, - new_value=self._new_value, + old_value=self._old_value._eval(), + new_value=self._new_value._eval(), ) diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py index 3098d2dc..d7dcaaff 100644 --- a/src/inline_snapshot/_snapshot/undecided_value.py +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -1,26 +1,138 @@ +import ast +from typing import Any from typing import Iterator -from inline_snapshot._adapter.adapter import adapter_map - -from .._adapter.adapter import AdapterContext -from .._adapter.adapter import get_adapter_type -from .._change import Change -from .._change import Replace -from .._sentinels import undefined -from .._unmanaged import Unmanaged -from .._unmanaged import map_unmanaged -from .._utils import value_to_token +from inline_snapshot._code_repr import mock_repr +from inline_snapshot._compare_context import compare_only +from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom import Custom +from inline_snapshot._customize._custom_call import CustomCall +from inline_snapshot._customize._custom_call import CustomDefault +from inline_snapshot._customize._custom_code import CustomCode +from inline_snapshot._customize._custom_dict import CustomDict +from inline_snapshot._customize._custom_sequence import CustomList +from inline_snapshot._customize._custom_sequence import CustomTuple +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged +from inline_snapshot._new_adapter import NewAdapter +from inline_snapshot._new_adapter import warn_star_expression +from inline_snapshot._unmanaged import is_unmanaged + +from .._adapter_context import AdapterContext +from .._change import ChangeBase from .generic_value import GenericValue +class AstToCustom: + context: AdapterContext + + def __init__(self, context): + self.eval = context.eval + self.context = context + + def convert(self, value: Any, node: ast.expr): + if is_unmanaged(value): + return CustomUnmanaged(value) + + if warn_star_expression(node, self.context): + return self.convert_generic(value, node) + + t = type(node).__name__ + return getattr(self, "convert_" + t, self.convert_generic)(value, node) + + def eval_convert(self, node): + return self.convert(self.eval(node), node) + + def convert_generic(self, value: Any, node: ast.expr): + if value is ...: + return CustomUndefined() + else: + return CustomCode(value, ast.unparse(node)) + + def convert_Call(self, value: Any, node: ast.Call): + return CustomCall( + self.eval_convert(node.func), + [self.eval_convert(a) for a in node.args], + {kw.arg: self.eval_convert(kw.value) for kw in node.keywords if kw.arg}, + ) + + def convert_List(self, value: list, node: ast.List): + + return CustomList([self.convert(v, n) for v, n in zip(value, node.elts)]) + + def convert_Tuple(self, value: tuple, node: ast.Tuple): + return CustomTuple([self.convert(v, n) for v, n in zip(value, node.elts)]) + + def convert_Dict(self, value: dict, node: ast.Dict): + return CustomDict( + { + self.convert(k, k_node): self.convert(v, v_node) + for (k, v), k_node, v_node in zip(value.items(), node.keys, node.values) + if k_node is not None + } + ) + + +class ValueToCustom: + """this implementation is for cpython <= 3.10 only. + It works similar to AstToCustom but can not handle calls + """ + + context: AdapterContext + + def __init__(self, context): + self.context = context + + def convert(self, value: Any): + if is_unmanaged(value): + return CustomUnmanaged(value) + + if isinstance(value, CustomDefault): + return self.convert(value.value) + + t = type(value).__name__ + return getattr(self, "convert_" + t, self.convert_generic)(value) + + def convert_generic(self, value: Any) -> Custom: + if value is ...: + return CustomUndefined() + else: + with mock_repr(self.context): + result = Builder(self.context, _recursive=False)._get_handler(value) + if isinstance(result, CustomCall) and result.function == type(value): + function = self.convert(result.function) + posonly_args = [self.convert(arg) for arg in result.args] + kwargs = {k: self.convert(arg) for k, arg in result.kwargs.items()} + + return CustomCall( + function=function, + args=posonly_args, + kwargs=kwargs, + ) + return CustomCode(value, "") + + def convert_list(self, value: list): + return CustomList([self.convert(v) for v in value]) + + def convert_tuple(self, value: tuple): + return CustomTuple([self.convert(v) for v in value]) + + def convert_dict(self, value: dict): + return CustomDict( + {self.convert(k): self.convert(v) for (k, v) in value.items()} + ) + + class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, context: AdapterContext): + self._context = context + self._ast_node = ast_node + + old_value = self.value_to_custom(old_value) - old_value = adapter_map(old_value, map_unmanaged) + assert isinstance(old_value, Custom) self._old_value = old_value - self._new_value = undefined - self._ast_node = ast_node - self._context = context + self._new_value = CustomUndefined() def _change(self, cls): self.__class__ = cls @@ -28,35 +140,21 @@ def _change(self, cls): def _new_code(self): assert False - def _get_changes(self) -> Iterator[Change]: + def _get_changes(self) -> Iterator[ChangeBase]: + assert isinstance(self._new_value, CustomUndefined) - def handle(node, obj): + new_value = self.to_custom(self._old_value._eval()) - adapter = get_adapter_type(obj) - if adapter is not None and hasattr(adapter, "items"): - for item in adapter.items(obj, node): - yield from handle(item.node, item.value) - return + adapter = NewAdapter(self._context) - if not isinstance(obj, Unmanaged) and node is not None: - new_token = value_to_token(obj) - if self._file._token_of_node(node) != new_token: - new_code = self._file._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - file=self._file, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - - yield from handle(self._ast_node, self._old_value) - - # functions which determine the type + for change in adapter.compare(self._old_value, self._ast_node, new_value): + assert change.flag == "update", change + yield change def __eq__(self, other): + if compare_only(): + return False + from .._snapshot.eq_value import EqValue self._change(EqValue) @@ -74,11 +172,11 @@ def __ge__(self, other): self._change(MaxValue) return self >= other - def __contains__(self, other): + def __contains__(self, item): from .._snapshot.collection_value import CollectionValue self._change(CollectionValue) - return other in self + return item in self def __getitem__(self, item): from .._snapshot.dict_value import DictValue diff --git a/src/inline_snapshot/_snapshot_session.py b/src/inline_snapshot/_snapshot_session.py index a103c718..63f26895 100644 --- a/src/inline_snapshot/_snapshot_session.py +++ b/src/inline_snapshot/_snapshot_session.py @@ -1,8 +1,8 @@ -import ast import os import sys import tokenize from pathlib import Path +from types import SimpleNamespace from typing import Dict from typing import List @@ -21,8 +21,6 @@ from ._change import ChangeBase from ._change import ExternalRemove from ._change import apply_all -from ._code_repr import used_hasrepr -from ._external._find_external import ensure_import from ._external._find_external import used_externals_in from ._flags import Flags from ._global_state import state @@ -240,6 +238,32 @@ def header(): class SnapshotSession: + def __init__(self): + self.registered_modules = set() + + def register_customize_hooks_from_module(self, module): + """Find and register functions decorated with @customize from a module""" + + if module.__file__ in self.registered_modules: + return + + self.registered_modules.add(module.__file__) + + hooks = {} + + for name in dir(module): + obj = getattr(module, name, None) + if isinstance(obj, type) and name.startswith("InlineSnapshot"): + state().pm.register(obj(), name=f"") + + if hasattr(obj, "inline_snapshot_impl"): + hooks[name] = obj + + if hooks: + state().pm.register( + SimpleNamespace(**hooks), name=f"" + ) + @staticmethod def test_enter(): state().missing_values = 0 @@ -334,10 +358,12 @@ def load_config(self, pyproject, cli_flags, parallel_run, error, project_root): state().active = "disable" not in flags state().update_flags = Flags(flags & categories) - if state().config.storage_dir is None: - state().config.storage_dir = project_root / ".inline-snapshot" + storage_dir = state().config.storage_dir + if storage_dir is None: + storage_dir = project_root / ".inline-snapshot" + state().config.storage_dir = storage_dir - state().all_storages = default_storages(state().config.storage_dir) + state().all_storages = default_storages(storage_dir) if flags - {"short-report", "disable"} and not is_pytest_compatible(): @@ -454,25 +480,4 @@ def console(): for change in used_changes: change.apply_external_changes() - for test_file in cr.files(): - tree = ast.parse(test_file.new_code()) - used_externals = used_externals_in( - test_file.filename, tree, check_import=False - ) - - required_imports = [] - - if used_externals: - required_imports.append("external") - - if used_hasrepr(tree): - required_imports.append("HasRepr") - - if required_imports: - ensure_import( - test_file.filename, - {"inline_snapshot": required_imports}, - cr, - ) - cr.fix_all() diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index d5c3dc15..44dae906 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -1,4 +1,5 @@ -import ast +import io +import token import tokenize from pathlib import Path @@ -8,9 +9,18 @@ from inline_snapshot._format import format_code from inline_snapshot._utils import normalize from inline_snapshot._utils import simple_token -from inline_snapshot._utils import value_to_token -from ._utils import ignore_tokens +ignore_tokens = (token.NEWLINE, token.ENDMARKER, token.NL) + + +def _token_of_code(code_repr): + input = io.StringIO(code_repr) + + return [ + simple_token(t.type, t.string) + for t in tokenize.generate_tokens(input.readline) + if t.type not in ignore_tokens + ] class SourceFile: @@ -23,29 +33,24 @@ def __init__(self, source: Source): def filename(self) -> str: return self._source.filename - def _format(self, text): + def _format(self, code): if self._source is None or enforce_formatting(): - return text + return code else: - return format_code(text, Path(self._source.filename)) + return format_code(code, Path(self._source.filename)) + + def format_expression(self, code: str) -> str: + return self._format(code).strip() def asttokens(self): return self._source.asttokens() - def _token_to_code(self, tokens): - if len(tokens) == 1 and tokens[0].type == 3: - try: - if ast.literal_eval(tokens[0].string) == "": - # https://github.com/15r10nk/inline-snapshot/issues/281 - # https://github.com/15r10nk/inline-snapshot/issues/258 - # this would otherwise cause a triple-quoted-string because black would format it as a docstring at the beginning of the code - return '""' - except: # pragma: no cover - pass - return self._format(tokenize.untokenize(tokens)).strip() - - def _value_to_code(self, value): - return self._token_to_code(value_to_token(value)) + def code_changed(self, old_node, new_code): + + if old_node is None: + return False + + return self._token_of_node(old_node) != _token_of_code(new_code) def _token_of_node(self, node): diff --git a/src/inline_snapshot/_types.py b/src/inline_snapshot/_types.py index 1f221c27..d0fb6caa 100644 --- a/src/inline_snapshot/_types.py +++ b/src/inline_snapshot/_types.py @@ -35,7 +35,7 @@ class Snapshot(Protocol[T]): ``` python from typing import Optional - from inline_snapshot import snapshot, Snapshot + from inline_snapshot import Snapshot, snapshot # required snapshots diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py index 3f33393a..63065770 100644 --- a/src/inline_snapshot/_unmanaged.py +++ b/src/inline_snapshot/_unmanaged.py @@ -8,9 +8,8 @@ def is_dirty_equal(value): else: def is_dirty_equal(value): - return isinstance(value, dirty_equals.DirtyEquals) or ( - isinstance(value, type) and issubclass(value, dirty_equals.DirtyEquals) - ) + t = value if isinstance(value, type) else type(value) + return any(x is dirty_equals.DirtyEquals for x in t.__mro__) def update_allowed(value): @@ -29,23 +28,3 @@ def declare_unmanaged(data_type): global unmanaged_types unmanaged_types.append(data_type) return data_type - - -class Unmanaged: - def __init__(self, value): - self.value = value - - def __eq__(self, other): - assert not isinstance(other, Unmanaged) - - return self.value == other - - def __repr__(self): - return repr(self.value) - - -def map_unmanaged(value): - if is_unmanaged(value): - return Unmanaged(value) - else: - return value diff --git a/src/inline_snapshot/_utils.py b/src/inline_snapshot/_utils.py index 8b1870f7..d1403c42 100644 --- a/src/inline_snapshot/_utils.py +++ b/src/inline_snapshot/_utils.py @@ -1,11 +1,15 @@ import ast -import io +import copy import token -import tokenize from collections import namedtuple from pathlib import Path +from types import BuiltinFunctionType +from types import FunctionType +from types import MethodType -from ._code_repr import code_repr +from inline_snapshot._exceptions import UsageError + +from ._code_repr import real_repr def link(text, link=None): @@ -73,9 +77,6 @@ def normalize(token_sequence): return skip_trailing_comma(normalize_strings(token_sequence)) -ignore_tokens = (token.NEWLINE, token.ENDMARKER, token.NL) - - # based on ast.unparse def _str_literal_helper(string, *, quote_types): """Helper for writing string literals, minimizing escapes. @@ -139,7 +140,9 @@ def __eq__(self, other): for s in (self.string, other.string) for suffix in ("f", "rf", "Rf", "F", "rF", "RF") ): - return False + # I don't know why this is not covered/(maybe needed) with the new customize algo + # but I think it is better to handle it as 'no cover' for now + return False # pragma: no cover return ast.literal_eval(self.string) == ast.literal_eval( other.string @@ -148,28 +151,22 @@ def __eq__(self, other): return super().__eq__(other) -def value_to_token(value): - input = io.StringIO(code_repr(value)) - - def map_string(tok): - """Convert strings with newlines in triple quoted strings.""" - if tok.type == token.STRING: - s = ast.literal_eval(tok.string) - if isinstance(s, str) and ( - ("\n" in s and s[-1] != "\n") or s.count("\n") > 1 - ): - # unparse creates a triple quoted string here, - # because it thinks that the string should be a docstring - triple_quoted_string = triple_quote(s) - - assert ast.literal_eval(triple_quoted_string) == s +def clone(obj): + if isinstance(obj, (type, FunctionType, BuiltinFunctionType, MethodType)): + return obj - return simple_token(tok.type, triple_quoted_string) + new = copy.deepcopy(obj) + if not obj == new: + raise UsageError( + f"""\ +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: - return simple_token(tok.type, tok.string) +value = {real_repr(obj)} +copied_value = copy.deepcopy(value) +assert value == copied_value - return [ - map_string(t) - for t in tokenize.generate_tokens(input.readline) - if t.type not in ignore_tokens - ] +Please fix the way your object is copied or your __eq__ implementation. +""" + ) + return new diff --git a/src/inline_snapshot/extra.py b/src/inline_snapshot/extra.py index f840aaf5..1aad98e9 100644 --- a/src/inline_snapshot/extra.py +++ b/src/inline_snapshot/extra.py @@ -83,9 +83,9 @@ def prints(*, stdout: Snapshot[str] = "", stderr: Snapshot[str] = ""): ``` python + import sys from inline_snapshot import snapshot from inline_snapshot.extra import prints - import sys def test_prints(): @@ -98,9 +98,9 @@ def test_prints(): ``` python hl_lines="7 8 9" + import sys from inline_snapshot import snapshot from inline_snapshot.extra import prints - import sys def test_prints(): @@ -114,11 +114,11 @@ def test_prints(): === "ignore stdout" - ``` python hl_lines="3 9 10" + ``` python hl_lines="2 9 10" + import sys + from dirty_equals import IsStr from inline_snapshot import snapshot from inline_snapshot.extra import prints - from dirty_equals import IsStr - import sys def test_prints(): @@ -168,9 +168,9 @@ def warns( ``` python + from warnings import warn from inline_snapshot import snapshot from inline_snapshot.extra import warns - from warnings import warn def test_warns(): @@ -182,9 +182,9 @@ def test_warns(): ``` python hl_lines="7" + from warnings import warn from inline_snapshot import snapshot from inline_snapshot.extra import warns - from warnings import warn def test_warns(): @@ -272,9 +272,9 @@ def test_request(): ``` python - from inline_snapshot.extra import Transformed - from inline_snapshot import snapshot import re + from inline_snapshot import snapshot + from inline_snapshot.extra import Transformed class Thing: @@ -337,9 +337,9 @@ def transformation(func): ``` python - from inline_snapshot.extra import transformation - from inline_snapshot import snapshot import re + from inline_snapshot import snapshot + from inline_snapshot.extra import transformation class Thing: diff --git a/src/inline_snapshot/fix_pytest_diff.py b/src/inline_snapshot/fix_pytest_diff.py index b481f2f2..48f23635 100644 --- a/src/inline_snapshot/fix_pytest_diff.py +++ b/src/inline_snapshot/fix_pytest_diff.py @@ -4,7 +4,6 @@ from inline_snapshot._is import Is from inline_snapshot._snapshot.generic_value import GenericValue -from inline_snapshot._unmanaged import Unmanaged def fix_pytest_diff(): @@ -23,19 +22,6 @@ def _pprint_snapshot( PrettyPrinter._dispatch[GenericValue.__repr__] = _pprint_snapshot - def _pprint_unmanaged( - self, - object: Any, - stream: IO[str], - indent: int, - allowance: int, - context: Set[int], - level: int, - ) -> None: - self._format(object.value, stream, indent, allowance, context, level) - - PrettyPrinter._dispatch[Unmanaged.__repr__] = _pprint_unmanaged - def _pprint_is( self, object: Any, diff --git a/src/inline_snapshot/plugin/__init__.py b/src/inline_snapshot/plugin/__init__.py new file mode 100644 index 00000000..65b8fc68 --- /dev/null +++ b/src/inline_snapshot/plugin/__init__.py @@ -0,0 +1,18 @@ +from inline_snapshot._customize._custom_code import Import +from inline_snapshot._customize._custom_code import ImportFrom + +from .._customize._builder import Builder +from .._customize._custom import Custom +from ._spec import InlineSnapshotPluginSpec +from ._spec import customize +from ._spec import hookimpl + +__all__ = ( + "InlineSnapshotPluginSpec", + "customize", + "hookimpl", + "Builder", + "Custom", + "Import", + "ImportFrom", +) diff --git a/src/inline_snapshot/plugin/_default_plugin.py b/src/inline_snapshot/plugin/_default_plugin.py new file mode 100644 index 00000000..9c11bc7d --- /dev/null +++ b/src/inline_snapshot/plugin/_default_plugin.py @@ -0,0 +1,402 @@ +import ast +import datetime +from collections import Counter +from collections import defaultdict +from dataclasses import MISSING +from dataclasses import fields +from dataclasses import is_dataclass +from enum import Enum +from enum import Flag +from pathlib import Path +from pathlib import PurePath +from types import BuiltinFunctionType +from types import FunctionType +from typing import Any +from typing import Dict + +from inline_snapshot._customize._builder import Builder +from inline_snapshot._customize._custom_code import ImportFrom +from inline_snapshot._customize._custom_undefined import CustomUndefined +from inline_snapshot._customize._custom_unmanaged import CustomUnmanaged +from inline_snapshot._external._outsource import Outsourced +from inline_snapshot._sentinels import undefined +from inline_snapshot._unmanaged import is_dirty_equal +from inline_snapshot._unmanaged import is_unmanaged +from inline_snapshot._utils import triple_quote + +from ._spec import customize + + +class InlineSnapshotPlugin: + @customize + def standard_handler(self, value, builder: Builder): + if isinstance(value, list): + return builder.create_list(value) + + if type(value) is tuple: + return builder.create_tuple(value) + + if isinstance(value, dict): + return builder.create_dict(value) + + @customize + def string_handler(self, value, builder: Builder): + if isinstance(value, str) and ( + ("\n" in value and value[-1] != "\n") or value.count("\n") > 1 + ): + + triple_quoted_string = triple_quote(value) + + assert ast.literal_eval(triple_quoted_string) == value + + return builder.create_code(triple_quoted_string) + + @customize(tryfirst=True) + def counter_handler(self, value, builder: Builder): + if isinstance(value, Counter): + return builder.create_call(Counter, [dict(value)]) + + @customize + def function_and_type_handler( + self, value, builder: Builder, local_vars: Dict[str, Any] + ): + if isinstance(value, (FunctionType, type)): + for name, local_value in local_vars.items(): + if local_value is value: + return builder.create_code(name) + + qualname = value.__qualname__.split("[")[0] + name = qualname.split(".")[0] + return builder.create_code( + qualname, imports=[ImportFrom(value.__module__, name)] + ) + + @customize + def builtin_function_handler(self, value, builder: Builder): + if isinstance(value, BuiltinFunctionType): + return builder.create_code(value.__name__) + + @customize + def datetime_handler(self, value, builder: Builder): + + if isinstance(value, datetime.datetime): + return builder.create_call( + datetime.datetime, + [value.year, value.month, value.day], + { + "hour": builder.with_default(value.hour, 0), + "minute": builder.with_default(value.minute, 0), + "second": builder.with_default(value.second, 0), + "microsecond": builder.with_default(value.microsecond, 0), + }, + ) + + if isinstance(value, datetime.date): + return builder.create_call( + datetime.date, [value.year, value.month, value.day] + ) + + if isinstance(value, datetime.time): + return builder.create_call( + datetime.time, + [], + { + "hour": builder.with_default(value.hour, 0), + "minute": builder.with_default(value.minute, 0), + "second": builder.with_default(value.second, 0), + "microsecond": builder.with_default(value.microsecond, 0), + }, + ) + + if isinstance(value, datetime.timedelta): + return builder.create_call( + datetime.timedelta, + [], + { + "days": builder.with_default(value.days, 0), + "seconds": builder.with_default(value.seconds, 0), + "microseconds": builder.with_default(value.microseconds, 0), + }, + ) + + @customize + def path_handler(self, value, builder: Builder): + if isinstance(value, Path): + return builder.create_call(Path, [value.as_posix()]) + + if isinstance(value, PurePath): + return builder.create_call(PurePath, [value.as_posix()]) + + def sort_set_values(self, set_values): + is_sorted = False + try: + set_values = sorted(set_values) + is_sorted = True + except TypeError: + # can not be sorted by value + pass + + set_values = list(map(repr, set_values)) + if not is_sorted: + # sort by string representation + set_values = sorted(set_values) + + return set_values + + @customize + def set_handler(self, value, builder: Builder): + if isinstance(value, set): + if len(value) == 0: + return builder.create_code("set()") + else: + return builder.create_code( + "{" + ", ".join(self.sort_set_values(value)) + "}" + ) + + @customize + def frozenset_handler(self, value, builder: Builder): + if isinstance(value, frozenset): + if len(value) == 0: + return builder.create_code("frozenset()") + else: + return builder.create_call(frozenset, [set(value)]) + + # -8<- [start:Enum] + @customize + def enum_handler(self, value, builder: Builder): + if isinstance(value, Enum): + qualname = type(value).__qualname__ + name = qualname.split(".")[0] + + return builder.create_code( + f"{type(value).__qualname__}.{value.name}", + imports=[ImportFrom(type(value).__module__, name)], + ) + + # -8<- [end:Enum] + + @customize + def flag_handler(self, value, builder: Builder): + if isinstance(value, Flag): + qualname = type(value).__qualname__ + name = qualname.split(".")[0] + + return builder.create_code( + " | ".join( + f"{qualname}.{flag.name}" for flag in type(value) if flag in value + ), + imports=[ImportFrom(type(value).__module__, name)], + ) + + @customize + def source_file_name_handler(self, value, builder: Builder, global_vars): + if "__file__" in global_vars and value == global_vars["__file__"]: + return builder.create_code("__file__") + + @customize + def dataclass_handler(self, value, builder: Builder): + + if is_dataclass(value) and not isinstance(value, type): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + + if field.default != MISSING: + field_value = builder.with_default(field_value, field.default) + + if field.default_factory != MISSING: + field_value = builder.with_default( + field_value, field.default_factory() + ) + + kwargs[field.name] = field_value + + return builder.create_call(type(value), [], kwargs) + + @customize + def namedtuple_handler(self, value, builder: Builder): + t = type(value) + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return + if not all(type(n) == str for n in f): + return + + return builder.create_call( + type(value), + [], + { + field: ( + getattr(value, field) + if field not in value._field_defaults + else builder.with_default( + getattr(value, field), value._field_defaults[field] + ) + ) + for field in value._fields + }, + ) + + @customize(tryfirst=True) + def defaultdict_handler(self, value, builder: Builder): + if isinstance(value, defaultdict): + return builder.create_call( + type(value), [value.default_factory, dict(value)], {} + ) + + @customize + def unmanaged_handler(self, value, builder: Builder): + if is_unmanaged(value): + return CustomUnmanaged(value=value) + + @customize + def undefined_handler(self, value, builder: Builder): + if value is undefined: + return CustomUndefined() + + @customize + def outsource_handler(self, value, builder: Builder): + if isinstance(value, Outsourced): + return builder.create_external( + value.data, format=value.suffix, storage=value.storage + ) + + +try: + import dirty_equals +except ImportError: # pragma: no cover + pass +else: + + from dirty_equals._utils import Omit + + class InlineSnapshotDirtyEqualsPlugin: + @customize(tryfirst=True) + def dirty_equals_handler(self, value, builder: Builder): + + if is_dirty_equal(value) and builder._build_new_value: + + if isinstance(value, type): + return builder.create_code( + value.__name__, + imports=[ImportFrom("dirty_equals", value.__name__)], + ) + else: + + args = [a for a in value._repr_args if a is not Omit] + kwargs = { + k: a for k, a in value._repr_kwargs.items() if a is not Omit + } + if type(value) == dirty_equals.IsNow: + kwargs.pop("approx") + if ( + isinstance(delta := kwargs["delta"], datetime.timedelta) + and delta.total_seconds() == 2 + ): + kwargs.pop("delta") + return builder.create_call(type(value), args, kwargs) + + @customize(tryfirst=True) + def is_now_handler(self, value, builder: Builder): + if value == dirty_equals.IsNow(): + return dirty_equals.IsNow() + + +try: + import attrs +except ImportError: # pragma: no cover + + pass + +else: + + class InlineSnapshotAttrsPlugin: + @customize + def attrs_handler(self, value, builder: Builder): + + if attrs.has(type(value)): + + kwargs = {} + + for field in attrs.fields(type(value)): + if field.repr: + field_value = getattr(value, field.name) + + if field.default is not attrs.NOTHING: + + default_value = ( + field.default + if not isinstance(field.default, attrs.Factory) # type: ignore + else ( + field.default.factory() + if not field.default.takes_self + else field.default.factory(value) + ) + ) + field_value = builder.with_default( + field_value, default_value + ) + + kwargs[field.name] = field_value + + return builder.create_call(type(value), [], kwargs) + + +try: + import pydantic +except ImportError: # pragma: no cover + + pass + +else: + # import pydantic + if pydantic.version.VERSION.startswith("1."): + # pydantic v1 + from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef] + + def get_fields(value): + return value.__fields__ + + else: + # pydantic v2 + from pydantic_core import PydanticUndefined + + def get_fields(value): + return type(value).model_fields + + from pydantic import BaseModel + + class InlineSnapshotPydanticPlugin: + @customize + def pydantic_model_handler(self, value, builder: Builder): + + if isinstance(value, BaseModel): + + kwargs = {} + + for name, field in get_fields(value).items(): # type: ignore + if getattr(field, "repr", True): + field_value = getattr(value, name) + + if ( + field.default is not PydanticUndefined + and field.default == field_value + ): + field_value = builder.with_default( + field_value, field.default + ) + + elif field.default_factory is not None: + field_value = builder.with_default( + field_value, field.default_factory() + ) + + kwargs[name] = field_value + + return builder.create_call(type(value), [], kwargs) diff --git a/src/inline_snapshot/plugin/_spec.py b/src/inline_snapshot/plugin/_spec.py new file mode 100644 index 00000000..e56eeaea --- /dev/null +++ b/src/inline_snapshot/plugin/_spec.py @@ -0,0 +1,108 @@ +from functools import partial +from typing import Any +from typing import Dict + +import pluggy + +from inline_snapshot._customize._builder import Builder + +inline_snapshot_plugin_name = "inline_snapshot" + +hookspec = pluggy.HookspecMarker(inline_snapshot_plugin_name) +""" +The pluggy hookspec marker for inline_snapshot. +""" + +hookimpl = pluggy.HookimplMarker(inline_snapshot_plugin_name) +""" +The pluggy hookimpl marker for inline_snapshot. +""" + +customize = partial(hookimpl, specname="customize") +""" +Decorator to mark a function as an implementation of the `customize` hook which can be used instead of `hookimpl(specname="customize")`. +""" + + +class InlineSnapshotPluginSpec: + @hookspec(firstresult=True) + def customize( + self, + value: Any, + builder: Builder, + local_vars: Dict[str, Any], + global_vars: Dict[str, Any], + ) -> Any: + """ + The customize hook is called every time a snapshot value should be converted into code. + + This hook allows you to control how inline-snapshot represents objects in generated code. + When multiple handlers are registered, they are called until one returns a non-None value. + `customize` is also called for each attribute of the converted hook which is not a Custom node, which means that a hook for `int` does not only apply for `snapshot(5)` but also for `snaspshot([1,2,3])` 3 times. + + Arguments: + value: The Python object that needs to be converted into source code representation. + This is the actual runtime value from your test. + builder: A Builder instance providing methods to construct custom code representations. + Use methods like `create_call()`, `create_dict()`, `create_external()`, etc. + local_vars: Dictionary mapping variable names to their values in the local scope. + Useful for referencing existing variables instead of creating new literals. + global_vars: Dictionary mapping variable names to their values in the global scope. + + Returns: + (Custom): created using [Builder][inline_snapshot.plugin.Builder] `create_*` methods. + (None): if this handler doesn't apply to the given value. + (Something else): when the next handler should process the value. + + Example: + You can use @customize when you want to specify multiple handlers in the same class: + + + === "with @customize" + + ``` python title="conftest.py" + from inline_snapshot.plugin import customize + + + class InlineSnapshotPlugin: + @customize + def binary_numbers(self, value, builder, local_vars, global_vars): + if isinstance(value, int): + return builder.create_code(bin(value)) + + @customize + def repeated_strings(self, value, builder): + if isinstance(value, str) and value == value[0] * len(value): + return builder.create_code(f"'{value[0]}'*{len(value)}") + ``` + + === "by method name" + + + ``` python title="conftest.py" + class InlineSnapshotPlugin: + def customize(self, value, builder, local_vars, global_vars): + if isinstance(value, int): + return builder.create_code(bin(value)) + + if isinstance(value, str) and value == value[0] * len(value): + return builder.create_code(f"'{value[0]}'*{len(value)}") + ``` + + + + ``` python title="test_customizations.py" + from inline_snapshot import snapshot + + + def test_customizations(): + assert ["aaaaaaaaaaaaaaa", "bbbbbb"] == snapshot(["a" * 15, "b" * 6]) + assert 18856 == snapshot(0b100100110101000) + ``` + + + + """ + + # @hookspec + # def format_code(self, filename, str) -> str: ... diff --git a/src/inline_snapshot/pytest_plugin.py b/src/inline_snapshot/pytest_plugin.py index f26250cb..86b3ca7e 100644 --- a/src/inline_snapshot/pytest_plugin.py +++ b/src/inline_snapshot/pytest_plugin.py @@ -126,10 +126,33 @@ class InlineSnapshotPlugin: def __init__(self): self.session = SnapshotSession() + @pytest.hookimpl(tryfirst=True) + def pytest_plugin_registered(self, plugin, manager): + """Register @customize hooks from conftest.py files""" + # Skip internal plugins + + if not hasattr(plugin, "__file__") or plugin.__file__ is None: + return + + # Only process conftest.py files + if "conftest.py" not in plugin.__file__: + return + + self.session.register_customize_hooks_from_module(plugin) + @pytest.hookimpl def pytest_configure(self, config): enter_snapshot_context() + # Register customize hooks from all already loaded conftest.py files + for plugin in config.pluginmanager.get_plugins(): + if ( + hasattr(plugin, "__file__") + and plugin.__file__ + and "conftest.py" in plugin.__file__ + ): + self.session.register_customize_hooks_from_module(plugin) + # setup default flags if is_pytest_compatible(): state().config.default_flags_tui = ["create", "review"] diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 4917cfcf..98211fa8 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -1,12 +1,12 @@ from __future__ import annotations +import importlib.util import os import platform import random import re import subprocess as sp import sys -import tokenize import traceback import uuid from argparse import ArgumentParser @@ -14,7 +14,6 @@ from io import StringIO from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any from typing import Callable from unittest.mock import patch @@ -319,6 +318,11 @@ def report_error(message): raise StopTesting(message) snapshot_flags = set() + old_modules = sys.modules + old_path = sys.path[:] + # Add tmp_path to sys.path so modules can be imported normally + sys.path.insert(0, str(tmp_path)) + try: enter_snapshot_context() session.load_config( @@ -332,21 +336,42 @@ def report_error(message): report_output = StringIO() console = Console(file=report_output, width=80) + # Load and register all conftest.py files first + for conftest_path in tmp_path.rglob("conftest.py"): + print("load> conftest", conftest_path) + + # Load conftest module using importlib + spec = importlib.util.spec_from_file_location( + f"conftest_{conftest_path.parent.name}", conftest_path + ) + assert spec and spec.loader + + conftest_module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = conftest_module + conftest_module.__file__ = str(conftest_path) + spec.loader.exec_module(conftest_module) + + # Register customize hooks from this conftest + session.register_customize_hooks_from_module(conftest_module) + tests_found = False for filename in tmp_path.rglob("test_*.py"): - globals: dict[str, Any] = {} - print("run> pytest-inline", filename) - with tokenize.open(filename) as f: - code = f.read() - exec( - compile(code, filename, "exec"), - globals, + print("run> pytest-inline", filename, *args) + + # Load module using importlib + spec = importlib.util.spec_from_file_location( + filename.stem, filename ) + assert spec and spec.loader + + module = importlib.util.module_from_spec(spec) + sys.modules[filename.stem] = module + spec.loader.exec_module(module) # run all test_* functions tests = [ v - for k, v in globals.items() + for k, v in module.__dict__.items() if (k.startswith("test_") or k == "test") and callable(v) ] tests_found |= len(tests) != 0 @@ -378,6 +403,8 @@ def fail(message): except StopTesting as e: assert stderr == f"ERROR: {e}\n" finally: + sys.modules = old_modules + sys.path = old_path leave_snapshot_context() if reported_categories is not None: @@ -445,6 +472,7 @@ def run_pytest( Returns: A new Example instance containing the changed files. """ + self.dump_files() with TemporaryDirectory() as dir: diff --git a/tests/adapter/test_change_types.py b/tests/adapter/test_change_types.py index add6804f..f28d3c17 100644 --- a/tests/adapter/test_change_types.py +++ b/tests/adapter/test_change_types.py @@ -31,7 +31,7 @@ def f(v): def code_repr(v): g = {} - exec(context + f"r=repr({a})", g) + exec(context + f"r=repr({v})", g) return g["r"] def code(a, b): diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index f0847176..4094d24d 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -121,6 +121,43 @@ def test_something(): ) +def test_dataclass_add_arguments(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + +def test_something(): + for _ in [1,2]: + assert A(a=1,b=5) == snapshot(A(a=1)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + +def test_something(): + for _ in [1,2]: + assert A(a=1,b=5) == snapshot(A(a=1, b=5)) +""" + } + ), + ) + + def test_dataclass_positional_arguments(): Example( """\ @@ -153,7 +190,7 @@ class A: def test_something(): for _ in [1,2]: - assert A(a=1) == snapshot(A(1,2)) + assert A(a=1) == snapshot(A(a=1)) """ } ), @@ -374,7 +411,7 @@ def test_something(): assert A(a=3) == snapshot(A(**{"a":5})),"not equal" """ ).run_inline( - ["--inline-snapshot=fix"], + ["--inline-snapshot=report"], raises=snapshot( """\ AssertionError: @@ -414,7 +451,7 @@ class A: c:int=0 def test_something(): - assert A(a=3,b=3,c=3) == snapshot(A(a = 3, b=3, c = 3)),"not equal" + assert A(a=3,b=3,c=3) == snapshot(A(a=3, b=3, c=3)),"not equal" """ } ), @@ -450,12 +487,8 @@ def test_something(): def test_remove_positional_argument(): Example( - """\ -from inline_snapshot import snapshot - -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument - - + { + "tests/helper.py": """\ class L: def __init__(self,*l): self.l=l @@ -464,20 +497,21 @@ def __eq__(self,other): if not isinstance(other,L): return NotImplemented return other.l==self.l +""", + "tests/conftest.py": """\ +from inline_snapshot.plugin import customize +from helper import L + +class InlineSnapshotPlugin: + @customize + def handle_L(self,value,builder): + if isinstance(value,L): + return builder.create_call(L,value.l) +""", + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from helper import L -class LAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value_type): - return issubclass(value_type,L) - - @classmethod - def arguments(cls, value): - return ([Argument(x) for x in value.l],{}) - - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name,int) - return value.l[pos_or_name] def test_L1(): for _ in [1,2]: @@ -490,39 +524,16 @@ def test_L2(): def test_L3(): for _ in [1,2]: assert L(1,2) == snapshot(L(1, 2)), "not equal" -""" +""", + } ).run_pytest(returncode=snapshot(1)).run_pytest( ["--inline-snapshot=fix"], changed_files=snapshot( { "tests/test_something.py": """\ from inline_snapshot import snapshot +from helper import L -from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument - - -class L: - def __init__(self,*l): - self.l=l - - def __eq__(self,other): - if not isinstance(other,L): - return NotImplemented - return other.l==self.l - -class LAdapter(GenericCallAdapter): - @classmethod - def check_type(cls, value_type): - return issubclass(value_type,L) - - @classmethod - def arguments(cls, value): - return ([Argument(x) for x in value.l],{}) - - @classmethod - def argument(cls, value, pos_or_name): - assert isinstance(pos_or_name,int) - return value.l[pos_or_name] def test_L1(): for _ in [1,2]: @@ -674,3 +685,90 @@ def test_list(): } ), ) + + +def test_dataclass_custom_init(): + + Example( + """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(a=5)) + + assert A(a=5) == snapshot(A(_a=4)) + assert A(a=5) == snapshot(A(a=4)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(a=5)) + + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) +""" + } + ), + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from dataclasses import dataclass +from inline_snapshot import snapshot + +@dataclass(init=False) +class A: + _a:int + + def __init__(self,a:int=None,_a:int=None): + self._a=a or _a + + @property + def a(self): + return self._a + + +def test_A(): + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) + + assert A(a=5) == snapshot(A(_a=5)) + assert A(a=5) == snapshot(A(_a=5)) +""" + } + ), + ) diff --git a/tests/adapter/test_general.py b/tests/adapter/test_general.py index 3afad504..71fb5684 100644 --- a/tests/adapter/test_general.py +++ b/tests/adapter/test_general.py @@ -45,3 +45,25 @@ def test_thing(): assert (i,) == (Is(i),) """ ).run_pytest(["--inline-snapshot=short-report"], report=snapshot("")) + + +def test_usageerror_unmanaged(): + + Example( + """\ +from inline_snapshot import snapshot,Is + + +def test_thing(): + assert [Is(5)] == snapshot([6]) +""" + ).run_inline( + ["--inline-snapshot=fix"], + report=snapshot(""), + raises=snapshot( + """\ +UsageError: +unmanaged values can not be compared with snapshots\ +""" + ), + ) diff --git a/tests/adapter/test_namedtuple.py b/tests/adapter/test_namedtuple.py new file mode 100644 index 00000000..74fdd417 --- /dev/null +++ b/tests/adapter/test_namedtuple.py @@ -0,0 +1,227 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_namedtuple_default_value(): + # Note: namedtuples with defaults are created using a different approach + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1, b=2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_add_arguments(): + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b'], defaults=[2]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1, b=5) == snapshot(A(a=1)) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b'], defaults=[2]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1, b=5) == snapshot(A(a=1, b=5)) +""" + } + ), + ) + + +def test_namedtuple_positional_arguments(): + Example( + """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(1, 2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot, Is +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c'], defaults=[2, []]) + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_typing(): + Example( + """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int + +def test_something(): + assert A(a=1, b=2) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int + +def test_something(): + assert A(a=1, b=2) == snapshot(A(a=1, b=2)) +""" + } + ), + ) + + +def test_namedtuple_typing_defaults(): + Example( + """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int = 2 + c: list = [] + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1, b=2, c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from typing import NamedTuple + +class A(NamedTuple): + a: int + b: int = 2 + c: list = [] + +def test_something(): + for _ in [1, 2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_namedtuple_nested(): + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +Inner = namedtuple('Inner', ['x', 'y']) +Outer = namedtuple('Outer', ['a', 'inner']) + +def test_something(): + assert Outer(a=1, inner=Inner(x=2, y=3)) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +Inner = namedtuple('Inner', ['x', 'y']) +Outer = namedtuple('Outer', ['a', 'inner']) + +def test_something(): + assert Outer(a=1, inner=Inner(x=2, y=3)) == snapshot(Outer(a=1, inner=Inner(x=2, y=3))) +""" + } + ), + ) + + +def test_namedtuple_mixed_args(): + # Test mixing positional and keyword arguments + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c']) + +def test_something(): + assert A(1, b=2, c=3) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +A = namedtuple('A', ['a', 'b', 'c']) + +def test_something(): + assert A(1, b=2, c=3) == snapshot(A(a=1, b=2, c=3)) +""" + } + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py index c1524e2f..48ad4570 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import importlib.util import os import platform import re @@ -24,6 +25,7 @@ from inline_snapshot._global_state import snapshot_env from inline_snapshot._rewrite_code import ChangeRecorder from inline_snapshot._types import Category +from inline_snapshot.testing._example import deterministic_uuid pytest_plugins = "pytester" @@ -102,7 +104,7 @@ def run(self, *flags_arg: Category): print("input:") print(textwrap.indent(source, " |", lambda line: True).rstrip()) - with snapshot_env() as state: + with snapshot_env() as state, deterministic_uuid(): recorder = ChangeRecorder() state.update_flags = flags state.all_storages["hash"] = inline_snapshot._external.HashStorage( @@ -113,7 +115,16 @@ def run(self, *flags_arg: Category): error = False try: - exec(compile(filename.read_text("utf-8"), filename, "exec"), {}) + # Load module using importlib + spec = importlib.util.spec_from_file_location( + filename.stem, filename + ) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + sys.modules[filename.stem] = module + spec.loader.exec_module(module) + else: + assert False, f"Could not load module from {filename}" except AssertionError: traceback.print_exc() error = True @@ -300,21 +311,35 @@ def setup(self, source: str, add_header=True): import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize_repr +from inline_snapshot.plugin import customize -@customize_repr -def _(value:FakeDatetime): - return value.__repr__().replace("FakeDatetime","datetime.datetime") +class InlineSnapshotPlugin: + @customize + def fakedatetime_handler(self,value,builder): + if isinstance(value,FakeDatetime): + return builder.create_code(value.__repr__().replace("FakeDatetime","datetime.datetime")) -@customize_repr -def _(value:FakeDate): - return value.__repr__().replace("FakeDate","datetime.date") + @customize + def fakedate_handler(self,value,builder): + if isinstance(value,FakeDate): + return builder.create_code(value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) def set_time(freezer): freezer.move_to(datetime.datetime(2024, 3, 14, 0, 0, 0, 0)) yield + +import uuid +import random + +rd = random.Random(0) + +def f(): + return uuid.UUID(int=rd.getrandbits(128), version=4) + +uuid.uuid4 = f + """ ) diff --git a/tests/external/storage/test_hash.py b/tests/external/storage/test_hash.py index e62bf818..2ccd28af 100644 --- a/tests/external/storage/test_hash.py +++ b/tests/external/storage/test_hash.py @@ -29,3 +29,27 @@ def test_a(): } ), ) + + +def test_same_hash(): + Example( + """\ +from inline_snapshot import external +def test_a(): + assert "a" == external("hash:") + assert "a" == external("hash:") +""", + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + ".inline-snapshot/external/ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb.txt": "a", + "tests/test_something.py": """\ +from inline_snapshot import external +def test_a(): + assert "a" == external("hash:ca978112ca1b*.txt") + assert "a" == external("hash:ca978112ca1b*.txt") +""", + } + ), + ) diff --git a/tests/external/test_external.py b/tests/external/test_external.py index 17283721..6a78578a 100644 --- a/tests/external/test_external.py +++ b/tests/external/test_external.py @@ -19,7 +19,7 @@ def test_basic(check_update): assert check_update( "assert outsource('text') == snapshot()", flags="create" ) == snapshot( - "assert outsource('text') == snapshot(external(\"hash:982d9e3eb996*.txt\"))" + "assert outsource('text') == snapshot(external(\"uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt\"))" ) @@ -48,9 +48,6 @@ def test_a(): ["--inline-snapshot=create"], changed_files=snapshot( { - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.bin": "test", - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.log": "test", - ".inline-snapshot/external/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.txt": "test", "tests/__inline_snapshot__/test_something/test_a/e3e70682-c209-4cac-a29f-6fbed82c07cd.txt": "test", "tests/__inline_snapshot__/test_something/test_a/eb1167b3-67a9-4378-bc65-c1e582e2e662.bin": "test", "tests/__inline_snapshot__/test_something/test_a/f728b4fa-4248-4e3a-8a5d-2f346baa9455.log": "test", @@ -74,37 +71,89 @@ def test_a(): ) -def test_diskstorage(): +def test_compare_outsource(): with snapshot_env(): + assert outsource("one") == outsource("one") + assert outsource("one") != outsource("two") - assert outsource("test4") == snapshot(external("hash:a4e624d686e0*.txt")) - assert outsource("test5") == snapshot(external("hash:a140c0c1eda2*.txt")) - assert outsource("test6") == snapshot(external("hash:ed0cb90bdfa4*.txt")) - with raises( - snapshot( - "StorageLookupError: hash collision files=['a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt', 'a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt']" - ) - ): - assert outsource("test4") == external("hash:a*.txt") +def test_hash_collision(): + e = ( + Example( + { + "pyproject.toml": """\ +[tool.inline-snapshot] +default-storage="hash" +""", + "tests/test_a.py": """ +from inline_snapshot import outsource,snapshot,external - with raises( - snapshot( - "StorageLookupError: hash 'bbbbb*.txt' is not found in the HashStorage" - ) - ): +def test_a(): + assert outsource("test4") == snapshot() + assert outsource("test5") == snapshot() + assert outsource("test6") == snapshot() + if False: + assert outsource("test4") == external("hash:a*.txt") +""", + } + ) + .run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + ".inline-snapshot/external/a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt": "test5", + ".inline-snapshot/external/a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt": "test4", + ".inline-snapshot/external/ed0cb90bdfa4f93981a7d03cff99213a86aa96a6cbcf89ec5e8889871f088727.txt": "test6", + "tests/test_a.py": """\ + +from inline_snapshot import outsource,snapshot,external + +def test_a(): + assert outsource("test4") == snapshot(external("hash:a4e624d686e0*.txt")) + assert outsource("test5") == snapshot(external("hash:a140c0c1eda2*.txt")) + assert outsource("test6") == snapshot(external("hash:ed0cb90bdfa4*.txt")) + if False: + assert outsource("test4") == external("hash:a*.txt") +""", + } + ), + ) + .replace("False", "True") + ) + + with raises( + snapshot( + "StorageLookupError: hash collision files=['a140c0c1eda2def2b830363ba362aa4d7d255c262960544821f556e16661b6ff.txt', 'a4e624d686e03ed2767c0abd85c14426b0b1157d2ce81d27bb4fe4f6f01d688a.txt']" + ) + ): + e.run_inline() + + +def test_hash_not_found(): + with raises( + snapshot( + "StorageLookupError: hash 'bbbbb*.txt' is not found in the HashStorage" + ) + ): + with snapshot_env(): assert outsource("test4") == external("hash:bbbbb*.txt") -def test_update_legacy_external_names(project): +def test_update_legacy_external_names(): ( Example( - """\ + { + "pyproject.toml": """\ +[tool.inline-snapshot] +default-storage="hash" +""", + "tests/test_something.py": """\ from inline_snapshot import outsource,snapshot def test_something(): assert outsource("foo") == snapshot() -""" +""", + } ) .run_pytest( ["--inline-snapshot=create"], @@ -179,13 +228,10 @@ def test_pytest_compare_external_bytes(project): from inline_snapshot import external def test_a(): - assert outsource(b"test") == snapshot( - external("hash:9f86d081884c*.bin") - ) + s=snapshot() + assert outsource(b"test") == s - assert outsource(b"test2") == snapshot( - external("hash:9f86d081884c*.bin") - ) + assert outsource(b"test2") == s """ ) @@ -194,7 +240,7 @@ def test_a(): assert result.errorLines() == ( snapshot( """\ -> assert outsource(b"test2") == snapshot( +> assert outsource(b"test2") == s E AssertionError: assert b'test2' == b'test' E \n\ E Use -v to get more diff @@ -203,7 +249,7 @@ def test_a(): if sys.version_info >= (3, 11) else snapshot( """\ -> assert outsource(b"test2") == snapshot( +> assert outsource(b"test2") == s E AssertionError """ ) @@ -227,12 +273,19 @@ def test_a(): from inline_snapshot import external def test_a(): - assert outsource("test") == snapshot(external("hash:9f86d081884c*.txt")) + assert outsource("test") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) """ ) def test_pytest_trim_external(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """\ def test_a(): @@ -288,6 +341,13 @@ def test_a(): def test_pytest_new_external(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """\ def test_a(): @@ -296,9 +356,7 @@ def test_a(): ) project.run() - assert project.storage() == snapshot( - ["9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08.txt"] - ) + assert project.storage() == snapshot([]) project.run("--inline-snapshot=create") @@ -308,6 +366,12 @@ def test_a(): def test_pytest_config_hash_length(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) project.setup( """\ def test_a(): @@ -415,7 +479,7 @@ def test_something(): from inline_snapshot import external def test_something(): from inline_snapshot import outsource,snapshot - assert outsource("test") == snapshot(external("hash:9f86d081884c*.txt")) + assert outsource("test") == snapshot(external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt")) test_something() \ """ @@ -432,7 +496,7 @@ def test_ensure_imports(tmp_path): ) with apply_changes() as recorder: - ensure_import(file, {"os": ["chdir", "environ"]}, recorder) + ensure_import(file, {"os": ["chdir", "environ"]}, set(), recorder) assert file.read_text("utf-8") == snapshot( """\ @@ -453,7 +517,7 @@ def test_ensure_imports_with_comment(tmp_path): ) with apply_changes() as recorder: - ensure_import(file, {"os": ["chdir"]}, recorder) + ensure_import(file, {"os": ["chdir"]}, set(), recorder) assert file.read_text("utf-8") == snapshot( """\ @@ -464,7 +528,36 @@ def test_ensure_imports_with_comment(tmp_path): ) +def test_ensure_imports_with_docstring(tmp_path): + file = tmp_path / "file.py" + file.write_bytes( + b"""\ +''' docstring ''' +from __future__ import annotations +""" + ) + + with apply_changes() as recorder: + ensure_import(file, {"os": ["chdir"]}, set(), recorder) + + assert file.read_text("utf-8") == snapshot( + """\ +''' docstring ''' +from __future__ import annotations + +from os import chdir +""" + ) + + def test_new_externals(project): + project.pyproject( + """\ +[tool.inline-snapshot] +default-storage="hash" +""" + ) + project.setup( """ @@ -479,10 +572,7 @@ def test_something(): project.run("--inline-snapshot=create") assert project.storage() == snapshot( - [ - "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt", - "8dc140e6fe831481a2005ae152ffe32a9974aa92a260dfbac780d6a87154bb0b.txt", - ] + ["2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt"] ) assert project.source == snapshot( @@ -504,10 +594,7 @@ def test_something(): project.run() assert project.storage() == snapshot( - [ - "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt", - "8dc140e6fe831481a2005ae152ffe32a9974aa92a260dfbac780d6a87154bb0b.txt", - ] + ["2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt"] ) @@ -540,15 +627,15 @@ def test_something(): ["--inline-snapshot=create"], changed_files=snapshot( { - ".inline-snapshot/external/2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae.txt": "foo", - "tests/__inline_snapshot__/test_something/test_something/e3e70682-c209-4cac-a29f-6fbed82c07cd.txt": "foo", + "tests/__inline_snapshot__/test_something/test_something/eb1167b3-67a9-4378-bc65-c1e582e2e662.txt": "foo", + "tests/__inline_snapshot__/test_something/test_something/f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt": "foo", "tests/test_something.py": """\ from inline_snapshot import external, snapshot,outsource def test_something(): - assert outsource("foo") == snapshot(external("hash:2c26b46b68ff*.txt")) - assert "foo" == external("uuid:e3e70682-c209-4cac-a29f-6fbed82c07cd.txt") + assert outsource("foo") == snapshot(external("uuid:eb1167b3-67a9-4378-bc65-c1e582e2e662.txt")) + assert "foo" == external("uuid:f728b4fa-4248-4e3a-8a5d-2f346baa9455.txt") """, } ), @@ -569,7 +656,7 @@ def test_something(): error=snapshot( """\ > assert "foo" == external("hash:aaaaaaaaaaaa*.txt") -> raise StorageLookupError(f"hash {name!r} is not found in the HashStorage") +> raise StorageLookupError( E inline_snapshot._external._storage._protocol.StorageLookupError: hash 'aaaaaaaaaaaa*.txt' is not found in the HashStorage """ ), diff --git a/tests/external/test_module_name.py b/tests/external/test_module_name.py new file mode 100644 index 00000000..78d84aee --- /dev/null +++ b/tests/external/test_module_name.py @@ -0,0 +1,59 @@ +"""Tests for module_name_of function coverage.""" + +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +def test_module_name_init_py_with_snapshot(): + """Test module_name_of when __init__.py itself contains the snapshot call.""" + Example( + { + "tests/__init__.py": "", + "tests/mypackage/b.py": """\ +from dataclasses import dataclass + +@dataclass +class B: + b:int +""", + "tests/mypackage/__init__.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +s = snapshot() +""", + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + from .mypackage import s,A + from .mypackage.b import B + assert s == [A(5),B(5)] +""", + } + ).run_pytest( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/mypackage/__init__.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +from tests.mypackage.b import B + +@dataclass +class A: + a:int + +s = snapshot([A(a=5), B(b=5)]) +""" + } + ), + returncode=1, + ).run_pytest( + ["--inline-snapshot=disable"] + ) diff --git a/tests/test_builder.py b/tests/test_builder.py new file mode 100644 index 00000000..3ef6505d --- /dev/null +++ b/tests/test_builder.py @@ -0,0 +1,34 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_default_arg(): + + e = Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize + +class InlineSnapshotPlugin: + @customize + def handler(self,value,builder): + if value==5: + return builder.with_default(5,builder.create_code("8")) +""", + "test_a.py": """\ +from inline_snapshot import snapshot +def test_a(): + assert 5==snapshot() +""", + } + ) + + e.run_inline( + ["--inline-snapshot=create"], + raises=snapshot( + """\ +UsageError: +default value can not be an Custom value\ +""" + ), + ) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 3657f3b5..8ceb291e 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -19,33 +19,91 @@ def test_enum(check_update): - assert ( - check_update( - """ + Example( + { + "color.py": """ from enum import Enum class color(Enum): val="val" +def get_color(): + return [color.val,color.val] -assert [color.val] == snapshot() + """, + "test_color.py": """\ +from inline_snapshot import snapshot +from color import get_color - """, - flags="create", - ) - == snapshot( - """\ +def test_enum(): + assert get_color() == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_color.py": """\ +from inline_snapshot import snapshot +from color import get_color -from enum import Enum +from color import color -class color(Enum): - val="val" +def test_enum(): + assert get_color() == snapshot([color.val, color.val]) +""" + } + ), + ) + + +def test_path(): + Example( + """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot -assert [color.val] == snapshot([color.val]) +folder="a" +def test_a(): + assert Path(folder,"b.txt") == snapshot() + assert PurePath(folder,"b.txt") == snapshot() """ - ) + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot + +folder="a" + +def test_a(): + assert Path(folder,"b.txt") == snapshot(Path("a/b.txt")) + assert PurePath(folder,"b.txt") == snapshot(PurePath("a/b.txt")) +""" + } + ), + ).replace( + '"a"', '"c"' + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from pathlib import Path,PurePath +from inline_snapshot import snapshot + +folder="c" + +def test_a(): + assert Path(folder,"b.txt") == snapshot(Path("c/b.txt")) + assert PurePath(folder,"b.txt") == snapshot(PurePath("c/b.txt")) +""" + } + ), ) @@ -334,11 +392,21 @@ def test_datatypes_explicit(): assert code_repr(default_dict) == snapshot("defaultdict(list, {5: [2], 3: [1]})") -def test_tuple(): +def test_fake_tuple1(): + + class FakeTuple(tuple): + _fields = 5 + + def __repr__(self): + return "FakeTuple()" + + assert code_repr(FakeTuple()) == snapshot("FakeTuple()") + + +def test_fake_tuple2(): class FakeTuple(tuple): - def __init__(self): - self._fields = 5 + _fields = (1, 2) def __repr__(self): return "FakeTuple()" @@ -347,9 +415,11 @@ def __repr__(self): def test_invalid_repr(check_update): - assert ( - check_update( - """\ + + Example( + """\ +from inline_snapshot import snapshot + class Thing: def __repr__(self): return "+++" @@ -359,12 +429,18 @@ def __eq__(self,other): return NotImplemented return True -assert Thing() == snapshot() -""", - flags="create", - ) - == snapshot( - """\ +def test_a(): + assert Thing() == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot import HasRepr + class Thing: def __repr__(self): return "+++" @@ -374,9 +450,11 @@ def __eq__(self,other): return NotImplemented return True -assert Thing() == snapshot(HasRepr(Thing, "+++")) +def test_a(): + assert Thing() == snapshot(HasRepr(Thing, "+++")) """ - ) + } + ), ) diff --git a/tests/test_customize.py b/tests/test_customize.py new file mode 100644 index 00000000..a27b2f29 --- /dev/null +++ b/tests/test_customize.py @@ -0,0 +1,406 @@ +import pytest + +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +@pytest.mark.parametrize( + "original,flag", [("'a'", "update"), ("'b'", "fix"), ("", "create")] +) +def test_custom_dirty_equal(original, flag): + + Example( + { + "tests/conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder +from dirty_equals import IsStr + +class InlineSnapshotPlugin: + @customize + def re_handler(self,value, builder: Builder): + if value == IsStr(regex="[a-z]"): + return builder.create_call(IsStr, [], {"regex": "[a-z]"}) +""", + "tests/test_something.py": f"""\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot({original}) == "a" +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +from dirty_equals import IsStr + +def test_a(): + assert snapshot(IsStr(regex="[a-z]")) == "a" +""" + } + ), + ) + + +@pytest.mark.parametrize( + "original,flag", + [("{'1': 1, '2': 2}", "update"), ("5", "fix"), ("", "create")], +) +def test_create_imports(original, flag): + + Example( + { + "tests/test_something.py": f"""\ +from inline_snapshot import snapshot + +def counter(): + from collections import Counter + return Counter("122") + +def test(): + assert counter() == snapshot({original}) +""" + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from inline_snapshot import snapshot + +from collections import Counter + +def counter(): + from collections import Counter + return Counter("122") + +def test(): + assert counter() == snapshot(Counter({"1": 1, "2": 2})) +""" + } + ), + ) + + +@pytest.mark.parametrize( + "original,flag", + [("ComplexObj(1, 2)", "update"), ("'wrong'", "fix"), ("", "create")], +) +def test_with_import(original, flag): + """Test that with_import adds both simple and nested module import statements correctly.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder, Import +from pkg.subpkg import ComplexObj + +class InlineSnapshotPlugin: + @customize + def complex_handler(self, value, builder: Builder): + if isinstance(value, ComplexObj): + return builder.create_code( + f"mod1.helper(pkg.subpkg.create({value.a!r}, {value.b!r}))", + imports=[Import("mod1"), Import("pkg.subpkg")] + ) +""", + "mod1.py": """\ +def helper(obj): + return obj +""", + "pkg/__init__.py": "", + "pkg/subpkg.py": """\ +class ComplexObj: + def __init__(self, a, b): + self.a = a + self.b = b + + def __eq__(self, other): + return isinstance(other, ComplexObj) and self.a == other.a and self.b == other.b + +def create(a, b): + return ComplexObj(a, b) +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot +from pkg.subpkg import ComplexObj + +def test_a(): + assert snapshot({original}) == ComplexObj(1, 2) +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from pkg.subpkg import ComplexObj + +import mod1 +import pkg.subpkg + +def test_a(): + assert snapshot(mod1.helper(pkg.subpkg.create(1, 2))) == ComplexObj(1, 2) +""" + } + ), + ).run_inline() + + +@pytest.mark.parametrize( + "original,flag", [("MyClass('value')", "update"), ("'wrong'", "fix")] +) +@pytest.mark.parametrize("existing_import", ["\nimport mymodule\n", ""]) +def test_with_import_preserves_existing(original, flag, existing_import): + """Test that with_import preserves existing import statements.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder, Import +from mymodule import MyClass + +class InlineSnapshotPlugin: + @customize + def myclass_handler(self, value, builder: Builder): + if isinstance(value, MyClass): + return builder.create_code( + f"mymodule.MyClass({value.value!r})", + imports=[Import("mymodule")] + ) +""", + "mymodule.py": """\ +class MyClass: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, MyClass) and self.value == other.value +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot +from mymodule import MyClass + +import os # just another import +{existing_import}\ + +def test_a(): + assert snapshot({original}) == MyClass("value") +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from mymodule import MyClass + +import os # just another import + +import mymodule + +def test_a(): + assert snapshot(mymodule.MyClass("value")) == MyClass("value") +""" + } + ), + ).run_inline() + + +def test_customized_value_mismatch_error(): + """Test that UsageError is raised when customized value doesn't match original.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + +class InlineSnapshotPlugin: + @customize + def bad_handler(self, value, builder: Builder): + if value == 42: + # Return a CustomCode with wrong value - repr evaluates to 100 but original is 42 + return builder.create_code("100") +""", + "test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot() == 42 +""", + } + ).run_inline( + ["--inline-snapshot=create"], + raises=snapshot( + """\ +UsageError: +Customized value does not match original value: + +original_value=42 + +customized_value=100 +customized_representation=CustomCode('100') +""" + ), + ) + + +@pytest.mark.parametrize("original,flag", [("'wrong'", "fix"), ("", "create")]) +def test_global_var_lookup(original, flag): + """Test that create_code can look up global variables.""" + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize +from inline_snapshot.plugin import Builder + +class InlineSnapshotPlugin: + @customize + def use_global(self, value, builder: Builder): + if value == "test_value": + return builder.create_code("GLOBAL_VAR") +""", + "test_something.py": f"""\ +from inline_snapshot import snapshot + +GLOBAL_VAR = "test_value" + +def test_a(): + assert snapshot({original}) == "test_value" +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +GLOBAL_VAR = "test_value" + +def test_a(): + assert snapshot(GLOBAL_VAR) == "test_value" +""" + } + ), + ) + + +@pytest.mark.parametrize("original,flag", [("'wrong'", "fix"), ("", "create")]) +def test_file_handler(original, flag): + """Test that __file__ handler creates correct code.""" + + Example( + { + "test_something.py": f"""\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot({original}) == __file__ +""", + } + ).run_inline( + [f"--inline-snapshot={flag}"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +def test_a(): + assert snapshot(__file__) == __file__ +""" + } + ), + ) + + +def test_datetime_types(): + """Test that datetime types generate correct snapshots with proper imports.""" + + Example( + { + "test_something.py": """\ +from datetime import datetime, date, time, timedelta +from inline_snapshot import snapshot + +def test_datetime_types(): + assert snapshot() == datetime(2024, 1, 15, 10, 30, 45, 123456) + assert snapshot() == date(2024, 1, 15) + assert snapshot() == time(10, 30, 45, 123456) + assert snapshot() == timedelta(days=1, hours=2, minutes=30) + assert snapshot() == timedelta(seconds=5, microseconds=123456) +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from datetime import datetime, date, time, timedelta +from inline_snapshot import snapshot + +def test_datetime_types(): + assert snapshot(datetime(2024, 1, 15, hour=10, minute=30, second=45, microsecond=123456)) == datetime(2024, 1, 15, 10, 30, 45, 123456) + assert snapshot(date(2024, 1, 15)) == date(2024, 1, 15) + assert snapshot(time(hour=10, minute=30, second=45, microsecond=123456)) == time(10, 30, 45, 123456) + assert snapshot(timedelta(days=1, seconds=9000)) == timedelta(days=1, hours=2, minutes=30) + assert snapshot(timedelta(seconds=5, microseconds=123456)) == timedelta(seconds=5, microseconds=123456) +""" + } + ), + ).run_inline() + + +def test_custom_children(): + Example( + { + "c.py": """\ +from dataclasses import dataclass + +@dataclass +class C: + i:int +""", + "conftest.py": """\ +from inline_snapshot.plugin import customize +from c import C + +@customize +def handler(value,builder): + if isinstance(value,C) and value.i==2: + return builder.create_call(C,[],{"i":builder.create_code("1+1")}) + """, + "test_example.py": """\ +from inline_snapshot import snapshot +from c import C + +def test(): + assert C(i=2) == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_example.py": """\ +from inline_snapshot import snapshot +from c import C + +def test(): + assert C(i=2) == snapshot(C(i=1 + 1)) +""" + } + ), + ).run_inline( + ["--inline-snapshot=fix"] + ) diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py index 2d51188f..a1231a47 100644 --- a/tests/test_dirty_equals.py +++ b/tests/test_dirty_equals.py @@ -159,3 +159,77 @@ def test_number(): } ), ) + + +def test_is_now_without_approx() -> None: + """Test that IsNow handler correctly removes 'approx' kwarg.""" + + Example( + """\ +from datetime import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot + +def test_time(): + assert datetime.now() == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from datetime import datetime +from dirty_equals import IsNow +from inline_snapshot import snapshot + +def test_time(): + assert datetime.now() == snapshot(IsNow()) +""" + } + ), + ) + + +def test_is_now_with_delta() -> None: + """Test that IsNow handler works with custom delta parameter in customization.""" + + Example( + { + "conftest.py": """\ +from datetime import timedelta +from inline_snapshot.plugin import customize +from dirty_equals import IsNow + +@customize +def is_now_with_tolerance(value): + if value == IsNow(delta=timedelta(minutes=10)): + return IsNow(delta=timedelta(minutes=10)) +""", + "test_something.py": """\ +from datetime import datetime, timedelta +from inline_snapshot import snapshot + +def test_time(): + # Compare a time from 5 minutes ago with snapshot + past_time = datetime.now() - timedelta(minutes=5) + assert past_time == snapshot() +""", + } + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from datetime import datetime, timedelta +from inline_snapshot import snapshot + +from dirty_equals import IsNow + +def test_time(): + # Compare a time from 5 minutes ago with snapshot + past_time = datetime.now() - timedelta(minutes=5) + assert past_time == snapshot(IsNow(delta=timedelta(seconds=600))) +""" + } + ), + ).run_inline() diff --git a/tests/test_docs.py b/tests/test_docs.py index 26ea291c..9c85ae2f 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -13,10 +13,12 @@ from typing import Optional from typing import TypeVar +import isort.api import pytest from executing import is_pytest_compatible from inline_snapshot import snapshot +from inline_snapshot._align import align from inline_snapshot._external._external_file import external_file from inline_snapshot._flags import Flags from inline_snapshot._global_state import snapshot_env @@ -29,7 +31,7 @@ class Block: code: str code_header: Optional[str] - block_options: str + block_options: Dict[str, str] line: int @@ -47,17 +49,19 @@ def map_code_blocks(file: Path, func): code = None indent = "" block_start_linenum: Optional[int] = None - block_options: Optional[str] = None + block_options: dict[str, str] = {} code_header = None header_line = "" + block_found = False for linenumber, line in enumerate(current_code.splitlines(), start=1): m = block_start.fullmatch(line) if m and not is_block: # ``` python + block_found = True block_start_linenum = linenumber indent = m[1] - block_options = m[2] + block_options = {m[0]: m[1] for m in re.findall(r'(\w*)="([^"]*)"', m[2])} block_lines = [] is_block = True continue @@ -65,7 +69,6 @@ def map_code_blocks(file: Path, func): if block_end.fullmatch(line.strip()) and is_block: # ``` is_block = False - assert block_options is not None assert block_start_linenum is not None code = "\n".join(block_lines) + "\n" @@ -92,9 +95,9 @@ def map_code_blocks(file: Path, func): if new_block.code_header is not None: new_lines.append(f"{indent}") - new_lines.append( - f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" - ) + options = " ".join(f'{k}="{v}"' for k, v in new_block.block_options.items()) + + new_lines.append(f"{indent}``` {('python '+options).strip()}") new_code = new_block.code.rstrip() if file.suffix == ".py": @@ -130,7 +133,8 @@ def map_code_blocks(file: Path, func): new_code = "\n".join(new_lines) + "\n" - assert external_file(file, format=".txt") == new_code + if block_found: + assert external_file(file, format=".txt") == new_code def test_map_code_blocks(tmp_path): @@ -202,12 +206,12 @@ def test_block(block): blocks=snapshot( [ Block( - code="print(1 + 1)\n", code_header=None, block_options="", line=2 + code="print(1 + 1)\n", code_header=None, block_options={}, line=2 ), Block( code="print(1 - 1)\n", code_header="inline-snapshot: create test", - block_options=' hl_lines="1 2 3"', + block_options={"hl_lines": "1 2 3"}, line=7, ), ] @@ -217,7 +221,7 @@ def test_block(block): def change_block(block): block.code = "# removed" block.code_header = "header" - block.block_options = "option a b c" + block.block_options = {"a": "b c"} test_doc( """\ @@ -232,7 +236,7 @@ def change_block(block): Block( code="# removed", code_header="header", - block_options="option a b c", + block_options={"a": "b c"}, line=2, ) ] @@ -241,7 +245,7 @@ def change_block(block): """\ text -``` python option a b c +``` python a="b c" # removed ``` """ @@ -286,7 +290,6 @@ def __eq__(self, other: Any): def file_test( file: Path, width: int = 80, - use_hl_lines: bool = True, ): """Test code blocks with the header @@ -310,15 +313,18 @@ def file_test( import datetime import pytest from freezegun.api import FakeDatetime,FakeDate -from inline_snapshot import customize_repr +from inline_snapshot.plugin import customize -@customize_repr -def _(value:FakeDatetime): - return value.__repr__().replace("FakeDatetime","datetime.datetime") +class InlineSnapshotPlugin: + @customize + def fakedatetime_handler(self,value,builder): + if isinstance(value,FakeDatetime): + return builder.create_code(value.__repr__().replace("FakeDatetime","datetime.datetime")) -@customize_repr -def _(value:FakeDate): - return value.__repr__().replace("FakeDate","datetime.date") + @customize + def fakedate_handler(self,value,builder): + if isinstance(value,FakeDate): + return builder.create_code(value.__repr__().replace("FakeDate","datetime.date")) @pytest.fixture(autouse=True) @@ -347,11 +353,15 @@ def test_block(block: Block): return block if block.code_header.startswith("inline-snapshot-lib:"): - extra_files[block.code_header.split()[1]].append(block.code) + name = block.code_header.split()[1] + extra_files[name].append(block.code) + block.block_options["title"] = name return block if block.code_header.startswith("inline-snapshot-lib-set:"): - extra_files[block.code_header.split()[1]] = [block.code] + name = block.code_header.split()[1] + extra_files[name] = [block.code] + block.block_options["title"] = name return block if block.code_header.startswith("todo-inline-snapshot:"): @@ -381,6 +391,15 @@ def test_block(block: Block): assert last_code is not None test_files = {"tests/test_example.py": last_code} else: + code = isort.api.sort_code_string( + code, + config=isort.Config( + profile="black", + combine_as_imports=True, + lines_between_sections=0, + ), + ) + block.code = code test_files = {"tests/test_example.py": code} example = Example({**std_files, **test_files}) @@ -427,30 +446,23 @@ def test_block(block: Block): ] ) - if use_hl_lines: - from inline_snapshot._align import align - - linenum = 1 - hl_lines = "" - - if last_code is not None and "first_block" not in options: - changed_lines = [] - alignment = align(last_code.split("\n"), new_code.split("\n")) - for c in alignment: - if c == "d": - continue - elif c == "m": - linenum += 1 - else: - changed_lines.append(str(linenum)) - linenum += 1 - if changed_lines: - hl_lines = f'hl_lines="{" ".join(changed_lines)}"' + linenum = 1 + + if last_code is not None and "first_block" not in options: + changed_lines = [] + alignment = align(last_code.split("\n"), new_code.split("\n")) + for c in alignment: + if c == "d": + continue + elif c == "m": + linenum += 1 else: - assert False, "no lines changed" - block.block_options = hl_lines - else: - pass # pragma: no cover + changed_lines.append(str(linenum)) + linenum += 1 + if changed_lines: + block.block_options["hl_lines"] = " ".join(changed_lines) + else: + assert False, "no lines changed" block.code = new_code @@ -467,4 +479,4 @@ def test_block(block: Block): print(file) - file_test(file, width=60, use_hl_lines=True) + file_test(file, width=60) diff --git a/tests/test_formatting.py b/tests/test_formatting.py index f918f1e6..706ce274 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -34,7 +34,7 @@ def test_something(): def test_something(): assert 1==snapshot(1) assert 1==snapshot(1) - assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) + assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]) """ } ), @@ -60,8 +60,8 @@ def test_something(): | + assert 1==snapshot(1) | | assert 1==snapshot(2) | | - assert list(range(20)) == snapshot() | -| + assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 | -| ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) | +| + assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, | +| 11, 12, 13, 14, 15, 16, 17, 18, 19]) | +------------------------------------------------------------------------------+ These changes will be applied, because you used create @@ -74,8 +74,8 @@ def test_something(): | assert 1==snapshot(1) | | - assert 1==snapshot(2) | | + assert 1==snapshot(1) | -| assert list(range(20)) == snapshot([0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 | -| ,11 ,12 ,13 ,14 ,15 ,16 ,17 ,18 ,19 ]) | +| assert list(range(20)) == snapshot([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, | +| 11, 12, 13, 14, 15, 16, 17, 18, 19]) | +------------------------------------------------------------------------------+ These changes will be applied, because you used fix diff --git a/tests/test_preserve_values.py b/tests/test_preserve_values.py index 3017e144..bbaacbec 100644 --- a/tests/test_preserve_values.py +++ b/tests/test_preserve_values.py @@ -174,7 +174,7 @@ def test_preserve_case_from_original_mr(check_update): 5, ], }, - "e": ({"f": 6, "g": 7},), + "e": ({"f": 3 + 3, "g": 7},), } ) """ diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index fdb9f016..de11ca65 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -1,3 +1,6 @@ +import pydantic +import pytest + from inline_snapshot import snapshot from inline_snapshot.testing import Example @@ -181,3 +184,46 @@ def test_something(): ), returncode=1, ) + + +@pytest.mark.skipif( + pydantic.version.VERSION.startswith("1."), + reason="pydantic 1 can not compare C[int]() with C()", +) +def test_pydantic_generic_class(): + Example( + """\ +from typing import Generic, TypeVar +from inline_snapshot import snapshot +from pydantic import BaseModel + +I=TypeVar("I") +class C(BaseModel,Generic[I]): + a:int + +def test_a(): + c=C[int](a=5) + + assert c == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "tests/test_something.py": """\ +from typing import Generic, TypeVar +from inline_snapshot import snapshot +from pydantic import BaseModel + +I=TypeVar("I") +class C(BaseModel,Generic[I]): + a:int + +def test_a(): + c=C[int](a=5) + + assert c == snapshot(C(a=5)) +""" + } + ), + ).run_inline() diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 2a9cb981..aa11b2b9 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -814,6 +814,7 @@ def test_storage_dir_config(project, tmp_path, storage_dir): "pyproject.toml": f"""\ [tool.inline-snapshot] storage-dir = {str(storage_dir)!r} +default-storage="hash" """, "tests/test_a.py": """\ from inline_snapshot import outsource, snapshot @@ -871,6 +872,7 @@ def test_find_pyproject_in_parent_directories(): "pyproject.toml": """\ [tool.inline-snapshot] hash-length=2 +default-storage="hash" """, "project/pytest.ini": "", "project/test_something.py": """\ @@ -904,6 +906,7 @@ def test_find_pyproject_in_workspace_project(): "sub_project/pyproject.toml": """\ [tool.inline-snapshot] hash-length=2 +default-storage="hash" """, "pyproject.toml": "[tool.pytest.ini_options]", "sub_project/test_something.py": """\ diff --git a/tests/test_without_node.py b/tests/test_without_node.py new file mode 100644 index 00000000..5f13d9c5 --- /dev/null +++ b/tests/test_without_node.py @@ -0,0 +1,74 @@ +import pytest +from executing import is_pytest_compatible + +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +@pytest.mark.skipIf( + is_pytest_compatible, reason="this is only a problem when executing can return None" +) +def test_without_node(): + + Example( + { + "conftest.py": """\ +from inline_snapshot.plugin import customize + +@customize +def handler(value,builder): + if value=="foo": + return builder.create_code("'foo'") +""", + "test_example.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsStr + +def test_foo(): + assert "not_foo" == snapshot(IsStr()) +""", + } + ).run_pytest() + + +def test_custom_default_case_in_ValueToCustom(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=5 + +def test_something(): + assert A(a=3) == snapshot(A(a=5)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_tuple_case_in_ValueToCustom(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=5 + +def test_something(): + assert (1,2) == snapshot((1,2)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot(None), + )