Skip to content

Commit

Permalink
Add de-serialization for tagged union types
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Nov 16, 2022
1 parent 55f870d commit a2262a3
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 2 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,35 @@ class Study:

Here, the two properties of `Study` (`left` and `right`) will refer to the same subtype `#/definitions/Image`.

## Union types

Serializing a union type entails serializing the active member type.

De-serializing discriminated (tagged) union types is based on a disjoint set of property values with type annotation `Literal[...]`. Consider the following example:

```python
@dataclass
class ClassA:
name: Literal["A", "a"]
value: str


@dataclass
class ClassB:
name: Literal["B", "b"]
value: str
```

Here, JSON representations of `ClassA` and `ClassB` are indistinguishable based on property names alone. However, the property `name` for `ClassA` can only take values `"A"` and `"a"`, and property `name` for `ClassB` can only take values `"B"` and `"b"`, hence a JSON object such as
```json
{ "name": "A", "value": "string" }
```
uniquely identifies `ClassA`, and can never match `ClassB`. The de-serializer can instantiate the appropriate class, and populate properties of the newly created instance.

All member types of a tagged union must share at least one property of a literal type, and the literal values of that property must be distinct across the member types.

When de-serializing regular union types that have no type tags, the first successfully matching type is selected. It is a parse error if all union member types have been exhausted without finding a match.

## Name mangling

If a Python class has a property augmented with an underscore (`_`) as per [PEP 8](https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles) to avoid conflict with a Python keyword (e.g. `for` or `in`), the underscore is removed when reading from or writing to JSON.
95 changes: 93 additions & 2 deletions strong_typing/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@
import inspect
import typing
import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union

from .core import JsonType
from .exception import JsonKeyError, JsonTypeError, JsonValueError
from .inspection import (
create_object,
enum_value_types,
get_class_properties,
get_class_property,
get_resolved_hints,
is_dataclass_instance,
is_dataclass_type,
is_named_tuple_type,
is_type_annotated,
is_type_literal,
is_type_optional,
unwrap_annotated_type,
unwrap_literal_values,
unwrap_optional_type,
)
from .mapping import python_field_to_json_property
Expand Down Expand Up @@ -312,7 +315,91 @@ def parse(self, data: JsonType) -> Any:
)


def get_literal_properties(typ: type) -> Set[str]:
    "Collects the name of every property in a class whose annotated type is a literal type."

    literal_names: Set[str] = set()
    for name, annotation in get_class_properties(typ):
        if is_type_literal(annotation):
            literal_names.add(name)
    return literal_names


def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
    "Computes the names of literal-typed properties that every one of the specified classes has."

    # nothing can discriminate an empty list of types
    if not types:
        return set()

    # a member that is not a plain class rules out tag-based discrimination
    if any(not isinstance(candidate, type) for candidate in types):
        return set()

    first, *rest = types
    common = get_literal_properties(first)
    for other in rest:
        common = common.intersection(get_literal_properties(other))
    return common


class TaggedUnionDeserializer(Deserializer):
    "De-serializes a JSON value with one or more disambiguating properties into a Python union type."

    # member types of the union, as passed to the constructor
    member_types: Tuple[type, ...]
    # names of the literal-typed properties shared by all member types
    disambiguating_properties: Set[str]
    # maps a (property name, literal value) pair to the parser of the member type declaring that value
    member_parsers: Dict[Tuple[str, Any], Deserializer]

    def __init__(self, member_types: Tuple[type, ...]) -> None:
        """
        Builds the lookup table from tag values to member type parsers.
        :param member_types: The member types of the union.
        :raises JsonTypeError: The same (property name, literal value) pair is encountered more than once.
        """

        self.member_types = member_types
        self.disambiguating_properties = get_discriminating_properties(member_types)
        self.member_parsers = {}
        for member_type in member_types:
            for property_name in self.disambiguating_properties:
                literal_type = get_class_property(member_type, property_name)
                if not literal_type:
                    continue

                for literal_value in unwrap_literal_values(literal_type):
                    tpl = (property_name, literal_value)
                    if tpl in self.member_parsers:
                        raise JsonTypeError(
                            f"disambiguating property `{property_name}` in type `{self.union_type}` has a duplicate value: {literal_value}"
                        )

                    self.member_parsers[tpl] = create_deserializer(member_type)

    @property
    def union_type(self) -> str:
        "A printable `Union[...]` representation of the member types, used in error messages."

        type_names = ", ".join(
            python_type_to_str(member_type) for member_type in self.member_types
        )
        return f"Union[{type_names}]"

    def parse(self, data: JsonType) -> Any:
        """
        Selects the member type parser based on the value of a disambiguating property, and delegates to it.
        :param data: The JSON data to parse; must be a JSON `object`.
        :returns: The result produced by the parser of the selected member type.
        :raises JsonTypeError: The data is not a JSON `object`, the tag value matches no member type, or no tag is present.
        """

        if not isinstance(data, dict):
            raise JsonTypeError(
                f"tagged union type `{self.union_type}` expects JSON `object` data but instead received: {data}"
            )

        for property_name in self.disambiguating_properties:
            disambiguating_value = data.get(property_name)
            if disambiguating_value is None:
                # an absent (or `null`) tag is skipped so that another disambiguating property may be tried
                continue

            try:
                member_parser = self.member_parsers.get(
                    (property_name, disambiguating_value)
                )
            except TypeError:
                # an unhashable tag value (JSON array or object) cannot be looked up;
                # report it as an invalid tag rather than leaking a bare `TypeError`
                member_parser = None

            if member_parser is None:
                raise JsonTypeError(
                    f"disambiguating property value is invalid for tagged union type `{self.union_type}`: {data}"
                )

            return member_parser.parse(data)

        raise JsonTypeError(
            f"disambiguating property value is missing for tagged union type `{self.union_type}`: {data}"
        )


class LiteralDeserializer(Deserializer):
"De-serializes a JSON value into a Python literal type."

values: Tuple[Any, ...]
parser: Deserializer

Expand Down Expand Up @@ -698,7 +785,11 @@ def _create_deserializer(typ: type) -> Deserializer:
elif origin_type is tuple:
return TupleDeserializer(typing.get_args(typ))
elif origin_type is Union:
return UnionDeserializer(typing.get_args(typ))
union_args = typing.get_args(typ)
if get_discriminating_properties(union_args):
return TaggedUnionDeserializer(union_args)
else:
return UnionDeserializer(union_args)
elif origin_type is Literal:
return LiteralDeserializer(typing.get_args(typ))

Expand Down
24 changes: 24 additions & 0 deletions strong_typing/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ def is_type_literal(typ: type) -> bool:
return typing.get_origin(typ) is Literal


def unwrap_literal_value(typ: type) -> Any:
    """
    Extracts the single constant value captured by a literal type.
    :param typ: The literal type `Literal[value]`.
    :returns: The value captured by the literal type.
    :raises TypeError: The literal type does not capture exactly one value.
    """

    args = unwrap_literal_values(typ)
    if len(args) != 1:
        # covers both too many values and none at all, e.g. `Literal[()]`
        raise TypeError(
            f"expected exactly one value in literal type but found {len(args)}"
        )

    return args[0]


def unwrap_literal_values(typ: type) -> Tuple[Any, ...]:
"""
Extracts the constant values captured by a literal type.
Expand Down Expand Up @@ -356,6 +371,15 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
return resolved_hints.items()


def get_class_property(typ: type, name: str) -> Optional[type]:
    "Finds the annotated type of the class property with the given name, or `None` if the class has no such property."

    matching_types = (
        annotation
        for candidate, annotation in get_class_properties(typ)
        if name == candidate
    )
    # the first match wins; fall back to `None` when nothing matches
    return next(matching_types, None)


def get_referenced_types(typ: type) -> List[type]:
"""
Extracts types indirectly referenced by this type.
Expand Down
20 changes: 20 additions & 0 deletions tests/sample_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,23 @@ def __init__(self):
"b": ValueExample(value=4),
"c": ValueExample(value=5),
}


@dataclass
class ClassA:
    "Sample union member whose literal-typed properties `name` (`A` or `a`) and `type` (`A`) act as tags."

    name: Literal["A", "a"]
    type: Literal["A"]
    value: str


@dataclass
class ClassB:
    "Sample union member whose literal-typed properties `name` (`B` or `b`) and `type` (`B`) act as tags."

    name: Literal["B", "b"]
    type: Literal["B"]
    value: str


@dataclass
class ClassC:
    "Sample union member with only the literal-typed properties `name` (`C` or `c`) and `type` (`C`)."

    name: Literal["C", "c"]
    type: Literal["C"]
30 changes: 30 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,36 @@ def test_deserialization_union(self):
SimpleDataclass(int_value=2004),
)

# class types with literal-based disambiguation
self.assertEqual(
json_to_object(
Union[ClassA, ClassB, ClassC],
{"type": "A", "name": "A", "value": "string"},
),
ClassA(name="A", type="A", value="string"),
)
self.assertEqual(
json_to_object(
Union[ClassA, ClassB, ClassC],
{"type": "B", "name": "B", "value": "string"},
),
ClassB(name="B", type="B", value="string"),
)
self.assertEqual(
json_to_object(
Union[ClassA, ClassB, ClassC],
{"type": "A", "name": "a", "value": "string"},
),
ClassA(name="a", type="A", value="string"),
)
self.assertEqual(
json_to_object(
Union[ClassA, ClassB, ClassC],
{"type": "B", "name": "b", "value": "string"},
),
ClassB(name="b", type="B", value="string"),
)

def test_object_deserialization(self):
"""Test composition and inheritance with object de-serialization."""

Expand Down

0 comments on commit a2262a3

Please sign in to comment.