Skip to content

Commit b6ea407

Browse files
committed
Added PydanticInputObjectType and tests.
1 parent 1d35ba3 commit b6ea407

File tree

11 files changed

+997
-50
lines changed

11 files changed

+997
-50
lines changed

graphene_pydantic/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .inputobjecttype.types import PydanticInputObjectType
12
from .objecttype.types import PydanticObjectType
23

3-
__all__ = ["PydanticObjectType"]
4+
__all__ = ["PydanticObjectType", "PydanticInputObjectType"]

graphene_pydantic/inputobjecttype/__init__.py

Whitespace-only changes.
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
import collections
2+
import collections.abc
3+
import datetime
4+
import decimal
5+
import enum
6+
import sys
7+
import typing as T
8+
import uuid
9+
10+
from graphene import UUID, Boolean, Enum, Float, InputField, Int, List, String, Union
11+
from graphene.types.base import BaseType
12+
from graphene.types.datetime import Date, DateTime, Time
13+
from pydantic import BaseModel
14+
from pydantic.fields import Field as PydanticField
15+
16+
from ..util import construct_union_class_name
17+
from .registry import Registry
18+
19+
try:
20+
# Pydantic pre-1.0
21+
from pydantic.fields import Shape
22+
23+
SHAPE_SINGLETON = (Shape.SINGLETON,)
24+
SHAPE_SEQUENTIAL = (
25+
Shape.LIST,
26+
Shape.TUPLE,
27+
Shape.TUPLE_ELLIPS,
28+
Shape.SEQUENCE,
29+
Shape.SET,
30+
)
31+
SHAPE_MAPPING = (Shape.MAPPING,)
32+
except ImportError:
33+
# Pydantic 1.0+
34+
from pydantic import fields
35+
36+
SHAPE_SINGLETON = (fields.SHAPE_SINGLETON,)
37+
SHAPE_SEQUENTIAL = (
38+
fields.SHAPE_LIST,
39+
fields.SHAPE_TUPLE,
40+
fields.SHAPE_TUPLE_ELLIPSIS,
41+
fields.SHAPE_SEQUENCE,
42+
fields.SHAPE_SET,
43+
)
44+
SHAPE_MAPPING = (fields.SHAPE_MAPPING,)
45+
46+
47+
try:
48+
from graphene.types.decimal import Decimal as GrapheneDecimal
49+
50+
DECIMAL_SUPPORTED = True
51+
except ImportError: # pragma: no cover
52+
# graphene 2.1.5+ is required for Decimals
53+
DECIMAL_SUPPORTED = False
54+
55+
56+
NONE_TYPE = None.__class__ # need to do this because mypy complains about type(None)
57+
58+
59+
class ConversionError(TypeError):
60+
pass
61+
62+
63+
def convert_pydantic_field(
64+
field: PydanticField,
65+
registry: Registry,
66+
parent_type: T.Type = None,
67+
model: T.Type[BaseModel] = None,
68+
**field_kwargs,
69+
) -> InputField:
70+
"""
71+
Convert a Pydantic model field into a Graphene type field that we can add
72+
to the generated Graphene data model type.
73+
"""
74+
declared_type = getattr(field, "type_", None)
75+
field_kwargs.setdefault(
76+
"type",
77+
convert_pydantic_type(
78+
declared_type, field, registry, parent_type=parent_type, model=model
79+
),
80+
)
81+
field_kwargs.setdefault("required", field.required)
82+
field_kwargs.setdefault("default_value", field.default)
83+
# TODO: find a better way to get a field's description. Some ideas include:
84+
# - hunt down the description from the field's schema, or the schema
85+
# from the field's base model
86+
# - maybe even (Sphinx-style) parse attribute documentation
87+
field_kwargs.setdefault("description", field.__doc__)
88+
89+
return InputField(**field_kwargs)
90+
91+
92+
def convert_pydantic_type(
93+
type_: T.Type,
94+
field: PydanticField,
95+
registry: Registry = None,
96+
parent_type: T.Type = None,
97+
model: T.Type[BaseModel] = None,
98+
) -> BaseType: # noqa: C901
99+
"""
100+
Convert a Pydantic type to a Graphene Field type, including not just the
101+
native Python type but any additional metadata (e.g. shape) that Pydantic
102+
knows about.
103+
"""
104+
graphene_type = find_graphene_type(
105+
type_, field, registry, parent_type=parent_type, model=model
106+
)
107+
if field.shape in SHAPE_SINGLETON:
108+
return graphene_type
109+
elif field.shape in SHAPE_SEQUENTIAL:
110+
# TODO: _should_ Sets remain here?
111+
return List(graphene_type)
112+
elif field.shape in SHAPE_MAPPING:
113+
raise ConversionError(f"Don't know how to handle mappings in Graphene.")
114+
115+
116+
def find_graphene_type(
117+
type_: T.Type,
118+
field: PydanticField,
119+
registry: Registry = None,
120+
parent_type: T.Type = None,
121+
model: T.Type[BaseModel] = None,
122+
) -> BaseType: # noqa: C901
123+
"""
124+
Map a native Python type to a Graphene-supported Field type, where possible,
125+
throwing an error if we don't know what to map it to.
126+
"""
127+
if type_ == uuid.UUID:
128+
return UUID
129+
elif type_ in (str, bytes):
130+
return String
131+
elif type_ == datetime.datetime:
132+
return DateTime
133+
elif type_ == datetime.date:
134+
return Date
135+
elif type_ == datetime.time:
136+
return Time
137+
elif type_ == bool:
138+
return Boolean
139+
elif type_ == float:
140+
return Float
141+
elif type_ == decimal.Decimal:
142+
return GrapheneDecimal if DECIMAL_SUPPORTED else Float
143+
elif type_ == int:
144+
return Int
145+
elif type_ in (tuple, list, set):
146+
# TODO: do Sets really belong here?
147+
return List
148+
elif registry and registry.get_type_for_model(type_):
149+
return registry.get_type_for_model(type_)
150+
elif registry and isinstance(type_, BaseModel):
151+
# If it's a Pydantic model that hasn't yet been wrapped with a ObjectType,
152+
# we can put a placeholder in and request that `resolve_placeholders()`
153+
# be called to update it.
154+
registry.add_placeholder_for_model(type_)
155+
# NOTE: this has to come before any `issubclass()` checks, because annotated
156+
# generic types aren't valid arguments to `issubclass`
157+
elif hasattr(type_, "__origin__"):
158+
return convert_generic_python_type(
159+
type_, field, registry, parent_type=parent_type, model=model
160+
)
161+
elif isinstance(type_, T.ForwardRef):
162+
# A special case! We have to do a little hackery to try and resolve
163+
# the type that this points to, by trying to reference a "sibling" type
164+
# to where this was defined so we can get access to that namespace...
165+
sibling = model or parent_type
166+
if not sibling:
167+
raise ConversionError(
168+
"Don't know how to convert the Pydantic field "
169+
f"{field!r} ({field.type_}), could not resolve "
170+
"the forward reference. Did you call `resolve_placeholders()`? "
171+
"See the README for more on forward references."
172+
)
173+
module_ns = sys.modules[sibling.__module__].__dict__
174+
resolved = type_._evaluate(module_ns, None)
175+
# TODO: make this behavior optional. maybe this is a place for the TypeOptions to play a role?
176+
if registry:
177+
registry.add_placeholder_for_model(resolved)
178+
return find_graphene_type(
179+
resolved, field, registry, parent_type=parent_type, model=model
180+
)
181+
elif issubclass(type_, enum.Enum):
182+
return Enum.from_enum(type_)
183+
else:
184+
raise ConversionError(
185+
f"Don't know how to convert the Pydantic field {field!r} ({field.type_})"
186+
)
187+
188+
189+
def convert_generic_python_type(
190+
type_: T.Type,
191+
field: PydanticField,
192+
registry: Registry = None,
193+
parent_type: T.Type = None,
194+
model: T.Type[BaseModel] = None,
195+
) -> BaseType: # noqa: C901
196+
"""
197+
Convert annotated Python generic types into the most appropriate Graphene
198+
Field type -- e.g. turn `typing.Union` into a Graphene Union.
199+
"""
200+
origin = type_.__origin__
201+
if not origin: # pragma: no cover # this really should be impossible
202+
raise ConversionError(f"Don't know how to convert type {type_!r} ({field})")
203+
204+
# NOTE: This is a little clumsy, but working with generic types is; it's hard to
205+
# decide whether the origin type is a subtype of, say, T.Iterable since typical
206+
# Python functions like `isinstance()` don't work
207+
if origin == T.Union:
208+
return convert_union_type(
209+
type_, field, registry, parent_type=parent_type, model=model
210+
)
211+
elif origin in (
212+
T.Tuple,
213+
T.List,
214+
T.Set,
215+
T.Collection,
216+
T.Iterable,
217+
list,
218+
set,
219+
) or issubclass(origin, collections.abc.Sequence):
220+
# TODO: find a better way of divining that the origin is sequence-like
221+
inner_types = getattr(type_, "__args__", [])
222+
if not inner_types: # pragma: no cover # this really should be impossible
223+
raise ConversionError(
224+
f"Don't know how to handle {type_} (generic: {origin})"
225+
)
226+
# Of course, we can only return a homogeneous type here, so we pick the
227+
# first of the wrapped types
228+
inner_type = inner_types[0]
229+
return List(
230+
find_graphene_type(
231+
inner_type, field, registry, parent_type=parent_type, model=model
232+
)
233+
)
234+
elif origin in (T.Dict, T.Mapping, collections.OrderedDict, dict) or issubclass(
235+
origin, collections.abc.Mapping
236+
):
237+
raise ConversionError("Don't know how to handle mappings in Graphene")
238+
else:
239+
raise ConversionError(f"Don't know how to handle {type_} (generic: {origin})")
240+
241+
242+
def convert_union_type(
243+
type_: T.Type,
244+
field: PydanticField,
245+
registry: Registry = None,
246+
parent_type: T.Type = None,
247+
model: T.Type[BaseModel] = None,
248+
):
249+
"""
250+
Convert an annotated Python Union type into a Graphene Union.
251+
"""
252+
inner_types = type_.__args__
253+
# We use a little metaprogramming -- create our own unique
254+
# subclass of graphene.Union that knows its constituent Graphene types
255+
parent_types = tuple(
256+
find_graphene_type(x, field, registry, parent_type=parent_type, model=model)
257+
for x in inner_types
258+
if x != NONE_TYPE
259+
)
260+
261+
# This is effectively a typing.Optional[T], which decomposes into a
262+
# typing.Union[None, T] -- we can return the Graphene type for T directly
263+
# since Pydantic will have already parsed it as optional
264+
if len(parent_types) == 1:
265+
return parent_types[0]
266+
267+
internal_meta_cls = type("Meta", (), {"types": parent_types})
268+
269+
union_cls = type(
270+
construct_union_class_name(inner_types), (Union,), {"Meta": internal_meta_cls}
271+
)
272+
return union_cls
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import typing as T
2+
from collections import defaultdict
3+
4+
from pydantic import BaseModel
5+
from pydantic.fields import Field
6+
7+
if T.TYPE_CHECKING: # pragma: no cover
8+
from .types import PydanticInputObjectType
9+
10+
11+
def assert_is_pydantic_input_object_type(obj_type: T.Type["PydanticInputObjectType"]):
12+
"""An object in this registry must be a PydanticInputObjectType."""
13+
from .types import PydanticInputObjectType
14+
15+
if not isinstance(obj_type, type) or not issubclass(
16+
obj_type, PydanticInputObjectType
17+
):
18+
raise TypeError(f"Expected PydanticInputObjectType, but got: {obj_type!r}")
19+
20+
21+
class Placeholder:
22+
def __init__(self, model: T.Type[BaseModel]):
23+
self.model = model
24+
25+
def __repr__(self):
26+
return f"{self.__class__.__name__}({self.model})"
27+
28+
29+
class Registry:
30+
"""Hold information about Pydantic models and how they (and their fields) map to Graphene types."""
31+
32+
def __init__(self):
33+
self._registry = {}
34+
self._registry_models = {}
35+
self._registry_object_fields = defaultdict(dict)
36+
37+
def register(self, obj_type: T.Type["PydanticInputObjectType"]):
38+
assert_is_pydantic_input_object_type(obj_type)
39+
40+
assert (
41+
obj_type._meta.registry == self
42+
), "Can't register models linked to another Registry"
43+
self._registry[obj_type._meta.model] = obj_type
44+
45+
def get_type_for_model(self, model: T.Type[BaseModel]) -> "PydanticInputObjectType":
46+
return self._registry.get(model)
47+
48+
def add_placeholder_for_model(self, model: T.Type[BaseModel]):
49+
if model in self._registry:
50+
return
51+
self._registry[model] = Placeholder(model)
52+
53+
def register_object_field(
54+
self,
55+
obj_type: T.Type["PydanticInputObjectType"],
56+
field_name: str,
57+
obj_field: Field,
58+
model: T.Type[BaseModel] = None,
59+
):
60+
assert_is_pydantic_input_object_type(obj_type)
61+
62+
if not field_name or not isinstance(field_name, str): # pragma: no cover
63+
raise TypeError(f"Expected a field name, but got: {field_name!r}")
64+
self._registry_object_fields[obj_type][field_name] = obj_field
65+
66+
def get_object_field_for_graphene_field(
67+
self, obj_type: "PydanticInputObjectType", field_name: str
68+
) -> Field:
69+
return self._registry_object_fields.get(obj_type, {}).get(field_name)
70+
71+
72+
registry: T.Optional[Registry] = None
73+
74+
75+
def get_global_registry() -> Registry:
76+
"""Return a global instance of Registry for common use."""
77+
global registry
78+
if not registry:
79+
registry = Registry()
80+
return registry
81+
82+
83+
def reset_global_registry():
84+
"""Clear the global instance of the registry."""
85+
global registry
86+
registry = None

0 commit comments

Comments
 (0)