From fc00e55113474ff483afb43dd3024a0c30f23c9a Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Wed, 22 Feb 2023 08:54:41 -0500 Subject: [PATCH] JSON codec installer should be nestable. --- src/measured/json.py | 54 ++++++++++--------- ..._pydantic.py => test_pydantic_and_json.py} | 17 ++++++ 2 files changed, 47 insertions(+), 24 deletions(-) rename tests/{test_pydantic.py => test_pydantic_and_json.py} (91%) diff --git a/src/measured/json.py b/src/measured/json.py index fc7bd22..1fd5105 100644 --- a/src/measured/json.py +++ b/src/measured/json.py @@ -56,38 +56,44 @@ def object_hook(self, o: Dict[str, Any]) -> Union[MeasuredType, Dict[str, Any]]: def codecs_installed() -> Generator[None, None, None]: """A context within which the standard library's `json` module will be aware of how to encode and decode `measured` objects.""" - original_encoder = json._default_encoder # type: ignore - original_decoder = json._default_decoder # type: ignore - original_object_hook = json.loads.__kwdefaults__["object_hook"] + outermost_context = False + if not isinstance(json._default_encoder, MeasuredJSONEncoder): # type: ignore + outermost_context = True - encoder = MeasuredJSONEncoder() - decoder = MeasuredJSONDecoder() + if outermost_context: + original_encoder = json._default_encoder # type: ignore + original_decoder = json._default_decoder # type: ignore + original_object_hook = json.loads.__kwdefaults__["object_hook"] - json._default_encoder = encoder # type: ignore - json._default_decoder = decoder # type: ignore - json.loads.__kwdefaults__["object_hook"] = decoder.object_hook + encoder = MeasuredJSONEncoder() + decoder = MeasuredJSONDecoder() - try: - from pydantic.json import ENCODERS_BY_TYPE as PYDANTIC_ENCODERS_BY_TYPE - except ImportError: # pragma: no cover - PYDANTIC_ENCODERS_BY_TYPE = {} + json._default_encoder = encoder # type: ignore + json._default_decoder = decoder # type: ignore + json.loads.__kwdefaults__["object_hook"] = decoder.object_hook + + try: + from pydantic.json import ENCODERS_BY_TYPE as PYDANTIC_ENCODERS_BY_TYPE + except ImportError: # pragma: no cover + PYDANTIC_ENCODERS_BY_TYPE = {} - PYDANTIC_ENCODERS_BY_TYPE[Dimension] = encoder - PYDANTIC_ENCODERS_BY_TYPE[Prefix] = encoder - PYDANTIC_ENCODERS_BY_TYPE[Unit] = encoder - PYDANTIC_ENCODERS_BY_TYPE[Quantity] = encoder + PYDANTIC_ENCODERS_BY_TYPE[Dimension] = encoder + PYDANTIC_ENCODERS_BY_TYPE[Prefix] = encoder + PYDANTIC_ENCODERS_BY_TYPE[Unit] = encoder + PYDANTIC_ENCODERS_BY_TYPE[Quantity] = encoder try: yield finally: - json._default_encoder = original_encoder # type: ignore - json._default_decoder = original_decoder # type: ignore - json.loads.__kwdefaults__["object_hook"] = original_object_hook - - del PYDANTIC_ENCODERS_BY_TYPE[Dimension] - del PYDANTIC_ENCODERS_BY_TYPE[Prefix] - del PYDANTIC_ENCODERS_BY_TYPE[Unit] - del PYDANTIC_ENCODERS_BY_TYPE[Quantity] + if outermost_context: + json._default_encoder = original_encoder # type: ignore + json._default_decoder = original_decoder # type: ignore + json.loads.__kwdefaults__["object_hook"] = original_object_hook + + del PYDANTIC_ENCODERS_BY_TYPE[Dimension] + del PYDANTIC_ENCODERS_BY_TYPE[Prefix] + del PYDANTIC_ENCODERS_BY_TYPE[Unit] + del PYDANTIC_ENCODERS_BY_TYPE[Quantity] _installer = codecs_installed() diff --git a/tests/test_pydantic.py b/tests/test_pydantic_and_json.py similarity index 91% rename from tests/test_pydantic.py rename to tests/test_pydantic_and_json.py index d2fc7d4..504387d 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic_and_json.py @@ -259,3 +259,20 @@ async def test_parent_api_roundtrip(client: AsyncClient, parent: ParentModel) -> response = await client.post("/parent", content=parent.json()) assert response.status_code == 200 assert ParentModel.parse_raw(response.text) == parent + + +def test_codec_installation_is_nestable() -> None: + with pytest.raises(TypeError, match="not JSON serializable"): + json.dumps(Length) + + with measured.json.codecs_installed(): + assert json.loads(json.dumps(Length)) is Length + with measured.json.codecs_installed(): + assert json.loads(json.dumps(Length)) is Length + with measured.json.codecs_installed(): + assert json.loads(json.dumps(Length)) is Length + assert json.loads(json.dumps(Length)) is Length + assert json.loads(json.dumps(Length)) is Length + + with pytest.raises(TypeError, match="not JSON serializable"): + json.dumps(Length)