-
Notifications
You must be signed in to change notification settings - Fork 853
/
Copy pathstrict_schema.py
162 lines (132 loc) · 5.56 KB
/
strict_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations
from typing import Any
from openai import NOT_GIVEN
from typing_extensions import TypeGuard
from .exceptions import UserError
# Canonical strict schema for "an object with no properties". Treat this as a
# read-only template: callers receive copies of it, never this object itself.
_EMPTY_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {},
    "required": [],
}


def ensure_strict_json_schema(
    schema: dict[str, Any],
) -> dict[str, Any]:
    """Make ``schema`` conform to the `strict` standard that the OpenAI API expects.

    A non-empty ``schema`` is mutated in place and the same dict is returned.
    An empty ``schema`` (``{}``) is left untouched; a fresh copy of the empty
    strict object schema is returned instead.

    Args:
        schema: The JSON schema to convert.

    Returns:
        The strict schema: ``schema`` itself after mutation, or a new dict
        equal to ``_EMPTY_SCHEMA`` when ``schema`` is empty.
    """
    if schema == {}:
        # Return a deep copy so that a caller mutating the result (including
        # the nested "properties" dict or "required" list) cannot corrupt the
        # shared module-level template for every later call.
        from copy import deepcopy

        return deepcopy(_EMPTY_SCHEMA)
    return _ensure_strict_json_schema(schema, path=(), root=schema)
# Adapted from https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],
    root: dict[str, object],
) -> dict[str, Any]:
    """Recursively rewrite ``json_schema`` in place into its strict form.

    The following rewrites are applied, in this order:

    * every entry under ``$defs`` and ``definitions`` is processed recursively;
    * ``type: object`` schemas without ``additionalProperties`` get
      ``additionalProperties: False``;
    * when ``properties`` is a dict, ``required`` is overwritten with all of
      its keys and every property schema is processed recursively;
    * ``items``, ``anyOf`` and ``allOf`` sub-schemas are processed recursively
      (a single-entry ``allOf`` is merged into the parent and removed);
    * a ``default`` whose value is ``None`` is removed;
    * a ``$ref`` that has sibling keys is replaced by the referenced schema,
      with the sibling keys taking priority, and the result is processed again.

    Args:
        json_schema: The (sub-)schema to process. Must be a dict.
        path: Keys leading from ``root`` to ``json_schema``; only used in the
            error message raised for a non-dict ``json_schema``.
        root: The top-level schema, used to resolve ``$ref`` values.

    Returns:
        The same ``json_schema`` dict object, after mutation.

    Raises:
        TypeError: If ``json_schema`` is not a dict, or if a ``$ref`` that has
            sibling keys is not a string.
        UserError: If an object schema sets ``additionalProperties`` to ``True``.
        ValueError: If a ``$ref`` cannot be resolved to a dict.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    # Shared definitions are fixed up in place; their return value is not
    # needed because the recursion mutates the very dicts stored here.
    defs = json_schema.get("$defs")
    if is_dict(defs):
        for def_name, def_schema in defs.items():
            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)

    definitions = json_schema.get("definitions")
    if is_dict(definitions):
        for definition_name, definition_schema in definitions.items():
            _ensure_strict_json_schema(
                definition_schema, path=(*path, "definitions", definition_name), root=root
            )

    typ = json_schema.get("type")
    if typ == "object" and "additionalProperties" not in json_schema:
        json_schema["additionalProperties"] = False
    elif (
        typ == "object"
        and "additionalProperties" in json_schema
        and json_schema["additionalProperties"] is True
    ):
        # Only the literal `True` is rejected here; any other value that is
        # already present is left as it is.
        raise UserError(
            "additionalProperties should not be set for object types. This could be because "
            "you're using an older version of Pydantic, or because you configured additional "
            "properties to be allowed. If you really need this, update the function or output tool "
            "to not use a strict schema."
        )

    # object types
    # { 'type': 'object', 'properties': { 'a': {...} } }
    properties = json_schema.get("properties")
    if is_dict(properties):
        # Every declared property becomes required.
        json_schema["required"] = list(properties)
        json_schema["properties"] = {
            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
            for key, prop_schema in properties.items()
        }

    # arrays
    # { 'type': 'array', 'items': {...} }
    items = json_schema.get("items")
    if is_dict(items):
        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)

    # unions
    any_of = json_schema.get("anyOf")
    if is_list(any_of):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
            for i, variant in enumerate(any_of)
        ]

    # intersections
    all_of = json_schema.get("allOf")
    if is_list(all_of):
        if len(all_of) == 1:
            # A one-element intersection is just that element: merge it into
            # the parent schema and drop the `allOf` wrapper.
            json_schema.update(
                _ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root)
            )
            json_schema.pop("allOf")
        else:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
                for i, entry in enumerate(all_of)
            ]

    # strip `None` defaults as there's no meaningful distinction here
    # the schema will still be `nullable` and the model will default
    # to using `None` anyway
    # (`NOT_GIVEN` is the fallback so a missing key is not mistaken for `None`.)
    if json_schema.get("default", NOT_GIVEN) is None:
        json_schema.pop("default")

    # we can't use `$ref`s if there are also other properties defined, e.g.
    # `{"$ref": "...", "description": "my description"}`
    #
    # so we unravel the ref
    # `{"type": "string", "description": "my description"}`
    ref = json_schema.get("$ref")
    if ref and has_more_than_n_keys(json_schema, 1):
        # Explicit raise rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a non-string `$ref` through.
        if not isinstance(ref, str):
            raise TypeError(f"Received non-string $ref - {ref}")

        resolved = resolve_ref(root=root, ref=ref)
        if not is_dict(resolved):
            raise ValueError(
                f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
            )

        # properties from the json schema take priority over the ones on the `$ref`
        json_schema.update({**resolved, **json_schema})
        json_schema.pop("$ref")
        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied
        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid
        return _ensure_strict_json_schema(json_schema, path=path, root=root)

    return json_schema
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    """Look up a local ``#/...`` reference inside ``root``.

    The part of ``ref`` after ``#/`` is split on ``/`` and each segment is used
    as a dictionary key, starting from ``root``.

    Args:
        root: The top-level schema the reference points into.
        ref: The reference string, e.g. ``"#/$defs/MyModel"``.

    Returns:
        The dict found at the end of the path (the stored object, not a copy).

    Raises:
        ValueError: If ``ref`` does not start with ``#/``, or if any entry
            along the path is not a dictionary.
        KeyError: If a path segment is missing from the schema.
    """
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    path = ref[2:].split("/")
    resolved = root
    for key in path:
        value = resolved[key]
        # Explicit raise rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a non-dict entry slip through.
        if not isinstance(value, dict):
            raise ValueError(
                f"encountered non-dictionary entry while resolving {ref} - {resolved}"
            )
        resolved = value

    return resolved
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
# just pretend that we know there are only `str` keys
# as that check is not worth the performance cost
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    """Return ``True`` when ``obj`` holds strictly more than ``n`` entries."""
    key_count = len(obj)
    return n < key_count