Skip to content

Commit 98d4e59

Browse files
chore(internal): support serialising iterable types (#1127)
1 parent 1ecf8f6 commit 98d4e59

File tree

5 files changed

+55
-3
lines changed

5 files changed

+55
-3
lines changed

src/openai/_utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
is_mapping as is_mapping,
1010
is_tuple_t as is_tuple_t,
1111
parse_date as parse_date,
12+
is_iterable as is_iterable,
1213
is_sequence as is_sequence,
1314
coerce_float as coerce_float,
1415
is_mapping_t as is_mapping_t,
@@ -33,6 +34,7 @@
3334
is_list_type as is_list_type,
3435
is_union_type as is_union_type,
3536
extract_type_arg as extract_type_arg,
37+
is_iterable_type as is_iterable_type,
3638
is_required_type as is_required_type,
3739
is_annotated_type as is_annotated_type,
3840
strip_annotated_type as strip_annotated_type,

src/openai/_utils/_transform.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from ._utils import (
1010
is_list,
1111
is_mapping,
12+
is_iterable,
1213
)
1314
from ._typing import (
1415
is_list_type,
1516
is_union_type,
1617
extract_type_arg,
18+
is_iterable_type,
1719
is_required_type,
1820
is_annotated_type,
1921
strip_annotated_type,
@@ -157,7 +159,12 @@ def _transform_recursive(
157159
if is_typeddict(stripped_type) and is_mapping(data):
158160
return _transform_typeddict(data, stripped_type)
159161

160-
if is_list_type(stripped_type) and is_list(data):
162+
if (
163+
# List[T]
164+
(is_list_type(stripped_type) and is_list(data))
165+
# Iterable[T]
166+
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
167+
):
161168
inner_type = extract_type_arg(stripped_type, 0)
162169
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
163170

src/openai/_utils/_typing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import Any, TypeVar, cast
3+
from typing import Any, TypeVar, Iterable, cast
4+
from collections import abc as _c_abc
45
from typing_extensions import Required, Annotated, get_args, get_origin
56

67
from .._types import InheritsGeneric
@@ -15,6 +16,12 @@ def is_list_type(typ: type) -> bool:
1516
return (get_origin(typ) or typ) == list
1617

1718

19+
def is_iterable_type(typ: type) -> bool:
20+
"""If the given type is `typing.Iterable[T]`"""
21+
origin = get_origin(typ) or typ
22+
return origin == Iterable or origin == _c_abc.Iterable
23+
24+
1825
def is_union_type(typ: type) -> bool:
1926
return _is_union(get_origin(typ))
2027

src/openai/_utils/_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
164164
return isinstance(obj, list)
165165

166166

167+
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
168+
return isinstance(obj, Iterable)
169+
170+
167171
def deepcopy_minimal(item: _T) -> _T:
168172
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
169173

tests/test_transform.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, List, Union, Optional
3+
from typing import Any, List, Union, Iterable, Optional, cast
44
from datetime import date, datetime
55
from typing_extensions import Required, Annotated, TypedDict
66

@@ -265,3 +265,35 @@ def test_pydantic_default_field() -> None:
265265
assert model.with_none_default == "bar"
266266
assert model.with_str_default == "baz"
267267
assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"}
268+
269+
270+
class TypedDictIterableUnion(TypedDict):
271+
foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")]
272+
273+
274+
class Bar8(TypedDict):
275+
foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
276+
277+
278+
class Baz8(TypedDict):
279+
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
280+
281+
282+
def test_iterable_of_dictionaries() -> None:
283+
assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]}
284+
assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]}
285+
286+
def my_iter() -> Iterable[Baz8]:
287+
yield {"foo_baz": "hello"}
288+
yield {"foo_baz": "world"}
289+
290+
assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}
291+
292+
293+
class TypedDictIterableUnionStr(TypedDict):
294+
foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
295+
296+
297+
def test_iterable_union_str() -> None:
298+
assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"}
299+
assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]

0 commit comments

Comments
 (0)