inline-snapshot
Advanced tools
| import pytest | ||
| from inline_snapshot import snapshot | ||
| from inline_snapshot.testing import Example | ||
| def _format_call(args, kwargs): | ||
| """Format a Call constructor with args and kwargs""" | ||
| args_str = ", ".join(str(a) for a in args) | ||
| kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items()) | ||
| parts = [p for p in [args_str, kwargs_str] if p] | ||
| return f"Call({', '.join(parts)})" | ||
| @pytest.mark.parametrize("actual_args", [(), (1,), (1, 0), (1, 0, 2), (5,)]) | ||
| @pytest.mark.parametrize( | ||
| "actual_kwargs", [{}, {"kw1": 2}, {"kw1": 2, "kw2": 3}, {"x": 1}] | ||
| ) | ||
| @pytest.mark.parametrize("snap_args", [(), (0,), (0, 1), (1, 2, 3), (5,)]) | ||
| @pytest.mark.parametrize( | ||
| "snap_kwargs", [{}, {"kw2": 5}, {"kw1": 5}, {"kw1": 2, "kw2": 3}, {"x": 1}] | ||
| ) | ||
| def test_call_fix(actual_args, actual_kwargs, snap_args, snap_kwargs): | ||
| """Test that fix flag properly updates mismatched Call snapshots""" | ||
| actual_call = _format_call(actual_args, actual_kwargs) | ||
| snap_call = _format_call(snap_args, snap_kwargs) | ||
| # Skip cases where actual and snapshot are the same | ||
| if actual_call == snap_call: | ||
| pytest.skip("No mismatch to fix") | ||
| Example( | ||
| { | ||
| "call_type.py": """\ | ||
| class Call: | ||
| def __init__(self, *args, **kwargs): | ||
| self.args = args | ||
| self.kwargs = kwargs | ||
| def __eq__(self, other): | ||
| if not isinstance(other, Call): | ||
| return NotImplemented | ||
| return self.args == other.args and self.kwargs == other.kwargs | ||
| """, | ||
| "conftest.py": """\ | ||
| from inline_snapshot.plugin import customize | ||
| from call_type import Call | ||
| @customize | ||
| def call_handler(value, builder): | ||
| if isinstance(value, Call): | ||
| return builder.create_call( | ||
| Call, | ||
| list(value.args), | ||
| dict(value.kwargs), | ||
| ) | ||
| """, | ||
| "test_call.py": f"""\ | ||
| from inline_snapshot import snapshot | ||
| from call_type import Call | ||
| def test_thing(): | ||
| assert {actual_call} == snapshot({snap_call}), "not equal" | ||
| """, | ||
| } | ||
| ).run_inline( | ||
| ["--inline-snapshot=fix"], | ||
| ).run_inline() | ||
| def test_call_map_exception(): | ||
| """Test that CustomCall._map raises TypeError with proper message when call fails""" | ||
| Example( | ||
| { | ||
| "bad_call_type.py": """\ | ||
| class BadCall: | ||
| def __init__(self, *args): | ||
| self.args = args | ||
| assert args, "args required" | ||
| def __eq__(self, other): | ||
| if not isinstance(other, BadCall): | ||
| return NotImplemented | ||
| return True | ||
| """, | ||
| "conftest.py": """\ | ||
| from inline_snapshot.plugin import customize | ||
| from bad_call_type import BadCall | ||
| @customize | ||
| def call_handler(value, builder): | ||
| if isinstance(value, BadCall): | ||
| return builder.create_call( | ||
| BadCall, | ||
| [], | ||
| ) | ||
| """, | ||
| "test_call.py": """\ | ||
| from inline_snapshot import snapshot | ||
| from bad_call_type import BadCall | ||
| def test_thing(): | ||
| assert BadCall(1) == snapshot(BadCall(4)) | ||
| """, | ||
| } | ||
| ).run_inline( | ||
| ["--inline-snapshot=fix"], | ||
| raises=snapshot( | ||
| """\ | ||
| TypeError: | ||
| can not call CustomCode('BadCall')()\ | ||
| """ | ||
| ), | ||
| ) |
+12
-0
| <a id='changelog-0.32.3'></a> | ||
| # 0.32.3 — 2026-02-24 | ||
| ## Changed | ||
| - Improved performance of `snapshot()` by using lazy evaluation. | ||
| ## Fixed | ||
| - Fixed code generation for Call objects created with `builder.create_call()`. | ||
| - use the handler for datetime types only for the concrete types and not the subclasses. | ||
| <a id='changelog-0.32.2'></a> | ||
@@ -3,0 +15,0 @@ # 0.32.2 — 2026-02-21 |
+1
-1
| Metadata-Version: 2.4 | ||
| Name: inline-snapshot | ||
| Version: 0.32.2 | ||
| Version: 0.32.3 | ||
| Summary: golden master/snapshot/approval testing library which puts the values right into your source code | ||
@@ -5,0 +5,0 @@ Project-URL: Changelog, https://15r10nk.github.io/inline-snapshot/latest/changelog/ |
| import ast | ||
| import sys | ||
| from dataclasses import dataclass | ||
| from typing import Optional | ||
| from functools import cached_property | ||
| from types import FrameType | ||
| from typing import cast | ||
| from executing import Source | ||
| from inline_snapshot._source_file import SourceFile | ||
@@ -14,8 +19,53 @@ | ||
| @dataclass | ||
| class AdapterContext: | ||
| file: SourceFile | ||
| frame: Optional[FrameContext] | ||
| qualname: str | ||
| _frame: FrameType | ||
| def __init__(self, frame: FrameType): | ||
| self._frame = frame | ||
| @cached_property | ||
| def expr(self): | ||
| return Source.executing(self._frame) | ||
| @property | ||
| def source(self) -> Source: | ||
| return cast( | ||
| Source, | ||
| getattr(self.expr, "source", None) if self.expr is not None else None, | ||
| ) | ||
| @property | ||
| def file(self): | ||
| return SourceFile(self.source) | ||
| @property | ||
| def frame(self) -> FrameContext: | ||
| return FrameContext(globals=self._frame.f_globals, locals=self._frame.f_locals) | ||
| @cached_property | ||
| def local_vars(self): | ||
| """Get local vars from snapshot context.""" | ||
| return { | ||
| var_name: var_value | ||
| for var_name, var_value in self._frame.f_locals.items() | ||
| if "@" not in var_name | ||
| } | ||
| @cached_property | ||
| def global_vars(self): | ||
| """Get global vars from snapshot context.""" | ||
| return { | ||
| var_name: var_value | ||
| for var_name, var_value in self._frame.f_globals.items() | ||
| if "@" not in var_name | ||
| } | ||
| @cached_property | ||
| def qualname(self) -> str: | ||
| if sys.version_info >= (3, 11): | ||
| return self._frame.f_code.co_qualname | ||
| else: | ||
| return self.expr.code_qualname() | ||
| def eval(self, node): | ||
@@ -22,0 +72,0 @@ assert self.frame is not None |
@@ -239,5 +239,6 @@ from __future__ import annotations | ||
| if len(parent_elements) in to_insert: | ||
| new_code += to_insert[len(parent_elements)] | ||
| elements += len(new_code) | ||
| for i, insert_elements in sorted(to_insert.items()): | ||
| if i >= len(parent_elements): | ||
| new_code += insert_elements | ||
| elements += len(new_code) | ||
@@ -340,17 +341,26 @@ if new_code or deleted or elements == 1 or len(parent_elements) <= 1: | ||
| other_call_args: list[CallArg] = [] | ||
| call_args: list[CallArg] = [] | ||
| for change in changes: | ||
| if isinstance(change, CallArg): | ||
| if change.arg_name is not None: | ||
| position = ( | ||
| change.arg_pos | ||
| if change.arg_pos is not None | ||
| else len(parent.args) + len(parent.keywords) | ||
| ) | ||
| to_insert[position].append( | ||
| f"{change.arg_name}={change.new_code}" | ||
| ) | ||
| else: | ||
| assert change.arg_pos is not None | ||
| to_insert[change.arg_pos].append(change.new_code) | ||
| call_args.append(change) | ||
| max_pos = len(parent.args) | ||
| for change in sorted(call_args, key=lambda arg: arg.arg_pos or 0): | ||
| if change.arg_name is None: | ||
| assert change.arg_pos is not None | ||
| to_insert[min(change.arg_pos, max_pos)].append(change.new_code) | ||
| else: | ||
| other_call_args.append(change) | ||
| end_pos = ( | ||
| max([len(parent.args) + len(parent.keywords), *to_insert.keys()]) + 1 | ||
| ) | ||
| for change in other_call_args: | ||
| position = change.arg_pos if change.arg_pos is not None else end_pos | ||
| to_insert[position].append(f"{change.arg_name}={change.new_code}") | ||
| generic_sequence_update( | ||
@@ -357,0 +367,0 @@ source, |
| from __future__ import annotations | ||
| import inspect | ||
| import warnings | ||
@@ -20,2 +21,15 @@ from contextlib import contextmanager | ||
| def adapter_context_for_parent_frame(): | ||
| from inline_snapshot._adapter_context import AdapterContext | ||
| frame = inspect.currentframe() | ||
| assert frame | ||
| frame = frame.f_back | ||
| assert frame | ||
| frame = frame.f_back | ||
| assert frame | ||
| return AdapterContext(frame) | ||
| class HasRepr: | ||
@@ -40,2 +54,3 @@ """This class is used for objects where `__repr__()` returns an non- | ||
| def __eq__(self, other): | ||
| if isinstance(other, HasRepr): | ||
@@ -48,3 +63,3 @@ if other._type is not self._type: | ||
| with mock_repr(None): | ||
| with mock_repr(adapter_context_for_parent_frame()): | ||
| other_repr = value_code_repr(other) | ||
@@ -84,6 +99,3 @@ return other_repr == self._str_repr or other_repr == real_repr(self) | ||
| def code_repr(obj): | ||
| from inline_snapshot._adapter_context import AdapterContext | ||
| context = AdapterContext(None, None, "<qualname>") | ||
| with mock_repr(context): | ||
| with mock_repr(adapter_context_for_parent_frame()): | ||
| return repr(obj) | ||
@@ -94,2 +106,4 @@ | ||
| def mock_repr(context: AdapterContext): | ||
| assert context is not None | ||
| def new_repr(obj): | ||
@@ -96,0 +110,0 @@ from inline_snapshot._customize._builder import Builder |
| from __future__ import annotations | ||
| from dataclasses import dataclass | ||
| from functools import cached_property | ||
| from typing import Any | ||
@@ -50,4 +49,4 @@ from typing import Callable | ||
| builder=self, | ||
| local_vars=self._get_local_vars, | ||
| global_vars=self._get_global_vars, | ||
| local_vars=self._local_vars, | ||
| global_vars=self._global_vars, | ||
| ) | ||
@@ -163,29 +162,11 @@ if r is None: | ||
| @cached_property | ||
| def _get_local_vars(self): | ||
| @property | ||
| def _local_vars(self): | ||
| """Get local vars from snapshot context.""" | ||
| if ( | ||
| self._snapshot_context is not None | ||
| and (frame := self._snapshot_context.frame) is not None | ||
| ): | ||
| return { | ||
| var_name: var_value | ||
| for var_name, var_value in frame.locals.items() | ||
| if "@" not in var_name | ||
| } | ||
| return {} | ||
| return self._snapshot_context.local_vars | ||
| @cached_property | ||
| def _get_global_vars(self): | ||
| @property | ||
| def _global_vars(self): | ||
| """Get global vars from snapshot context.""" | ||
| if ( | ||
| self._snapshot_context is not None | ||
| and (frame := self._snapshot_context.frame) is not None | ||
| ): | ||
| return { | ||
| var_name: var_value | ||
| for var_name, var_value in frame.globals.items() | ||
| if "@" not in var_name | ||
| } | ||
| return {} | ||
| return self._snapshot_context.global_vars | ||
@@ -231,4 +212,4 @@ def _build_import_vars(self, imports): | ||
| # Direct lookup with proper precedence: local > import > global | ||
| if code in self._get_local_vars: | ||
| return CustomCode(self._get_local_vars[code], code, imports) | ||
| if code in self._local_vars: | ||
| return CustomCode(self._local_vars[code], code, imports) | ||
@@ -240,4 +221,4 @@ # Build import vars only if needed | ||
| if code in self._get_global_vars: | ||
| return CustomCode(self._get_global_vars[code], code, imports) | ||
| if code in self._global_vars: | ||
| return CustomCode(self._global_vars[code], code, imports) | ||
@@ -255,6 +236,6 @@ # Try ast.literal_eval for simple literals (fast and safe) | ||
| eval_context = { | ||
| **self._get_global_vars, | ||
| **self._global_vars, | ||
| **import_vars, | ||
| **self._get_local_vars, | ||
| **self._local_vars, | ||
| } | ||
| return CustomCode(eval(code, eval_context), code, imports) |
@@ -60,5 +60,14 @@ from __future__ import annotations | ||
| def _map(self, f): | ||
| return self.function._map(f)( | ||
| *[f(x._map(f)) for x in self.args], | ||
| **{k: f(v._map(f)) for k, v in self.kwargs.items()}, | ||
| ) | ||
| args = [f(x._map(f)) for x in self.args] | ||
| kwargs = {k: f(v._map(f)) for k, v in self.kwargs.items()} | ||
| try: | ||
| return self.function._map(f)( | ||
| *args, | ||
| **kwargs, | ||
| ) | ||
| except Exception as e: | ||
| call_args = args + [f"{k}={v}" for k, v in kwargs.items()] | ||
| raise TypeError( | ||
| f"can not call {self.function}({', '.join(map(str,call_args))})" | ||
| ) from e |
@@ -6,11 +6,6 @@ import ast | ||
| from typing import TypeVar | ||
| from typing import cast | ||
| from executing import Source | ||
| from inline_snapshot._adapter_context import AdapterContext | ||
| from inline_snapshot._adapter_context import FrameContext | ||
| from inline_snapshot._customize._custom_undefined import CustomUndefined | ||
| from inline_snapshot._generator_utils import with_flag | ||
| from inline_snapshot._source_file import SourceFile | ||
| from inline_snapshot._types import SnapshotRefBase | ||
@@ -56,13 +51,4 @@ | ||
| expr = Source.executing(frame) | ||
| context = AdapterContext(frame) | ||
| source = cast(Source, getattr(expr, "source", None) if expr is not None else None) | ||
| context = AdapterContext( | ||
| file=SourceFile(source), | ||
| frame=FrameContext(globals=frame.f_globals, locals=frame.f_locals), | ||
| qualname=expr.code_qualname(), | ||
| ) | ||
| Type.check_context(context) | ||
| if not state().active: | ||
@@ -76,13 +62,14 @@ if obj is undefined: | ||
| Type.check_context(context) | ||
| key = id(frame.f_code), frame.f_lasti | ||
| if key not in state().snapshots: | ||
| node = expr.node | ||
| node = context.expr.node | ||
| if node is None: | ||
| # we can run without knowing of the calling expression but we will not be able to fix code | ||
| new = Type(obj, None, context) | ||
| state().snapshots[key] = Type(obj, None, context) | ||
| else: | ||
| assert isinstance(node, ast.Call) | ||
| new = Type(obj, expr, context) | ||
| new = Type(obj, context.expr, context) | ||
| state().snapshots[key] = new | ||
@@ -89,0 +76,0 @@ else: |
@@ -442,2 +442,3 @@ from __future__ import annotations | ||
| ) | ||
| result_args.append(insert_value) | ||
@@ -470,3 +471,3 @@ # keyword arguments | ||
| to_insert = [] | ||
| insert_pos = 0 | ||
| insert_pos = len(old_value.args) | ||
| for key, new_value_element in new_kwargs.items(): | ||
@@ -513,3 +514,3 @@ if isinstance(new_value_element, CustomDefault): | ||
| node=old_node, | ||
| arg_pos=insert_pos, | ||
| arg_pos=None, | ||
| arg_name=key, | ||
@@ -516,0 +517,0 @@ new_code=new_code, |
@@ -81,3 +81,3 @@ import ast | ||
| def timezone_handler(self, value, builder: Builder): | ||
| if isinstance(value, datetime.timezone): | ||
| if type(value) is datetime.timezone: | ||
| # Handle timezone.utc specially - it's a constant, not a constructor call | ||
@@ -98,3 +98,3 @@ if value == datetime.timezone.utc: | ||
| if isinstance(value, datetime.datetime): | ||
| if type(value) is datetime.datetime: | ||
| return builder.create_call( | ||
@@ -112,3 +112,3 @@ datetime.datetime, | ||
| if isinstance(value, datetime.date): | ||
| if type(value) is datetime.date: | ||
| return builder.create_call( | ||
@@ -118,3 +118,3 @@ datetime.date, [value.year, value.month, value.day] | ||
| if isinstance(value, datetime.time): | ||
| if type(value) is datetime.time: | ||
| return builder.create_call( | ||
@@ -132,3 +132,3 @@ datetime.time, | ||
| if isinstance(value, datetime.timedelta): | ||
| if type(value) is datetime.timedelta: | ||
| return builder.create_call( | ||
@@ -135,0 +135,0 @@ datetime.timedelta, |
@@ -5,5 +5,5 @@ is_insider = False | ||
| __version__ = "0.32.2" | ||
| __version__ = "0.32.3" | ||
| if is_insider: | ||
| __version__ += "." + insider_version |
@@ -8,4 +8,5 @@ import pytest | ||
| @pytest.mark.skipIf( | ||
| is_pytest_compatible, reason="this is only a problem when executing can return None" | ||
| @pytest.mark.skipif( | ||
| is_pytest_compatible(), | ||
| reason="this is only a problem when executing can return None", | ||
| ) | ||
@@ -12,0 +13,0 @@ def test_without_node(): |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
3277349
0.14%191
0.53%14333
0.94%