Skip to content

Commit e8e381f

Browse files
authored
Merge pull request #408 from airtai/master
Update pydantic to v2 and update datamodel-code-generator to 0.25.6
2 parents fb88eba + f11f922 commit e8e381f

File tree

30 files changed

+740
-960
lines changed

30 files changed

+740
-960
lines changed

.github/workflows/test.yml

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@ jobs:
1010
strategy:
1111
fail-fast: false
1212
matrix:
13-
python-version: [3.7, 3.8, 3.9]
13+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
1414
os: [ubuntu-latest, windows-latest, macos-latest]
1515

1616
steps:
17-
- uses: actions/checkout@v1
18-
- uses: actions/cache@v1
17+
- uses: actions/checkout@v4
18+
- uses: actions/cache@v4
1919
with:
2020
path: ~/.cache/pip
2121
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
2222
restore-keys: |
2323
${{ runner.os }}-pip-
2424
- name: Set up Python ${{ matrix.python-version }}
25-
uses: actions/setup-python@v1
25+
uses: actions/setup-python@v5
2626
with:
2727
python-version: ${{ matrix.python-version }}
2828
- name: Install dependencies
2929
run: |
30+
python -m pip install --upgrade pip --disable-pip-version-check
3031
python -m pip install poetry
3132
poetry install
3233
- name: Lint
@@ -43,9 +44,21 @@ jobs:
4344
./scripts/poetry_test.bat
4445
- name: Upload coverage to Codecov
4546
if: matrix.os == 'ubuntu-latest'
46-
uses: codecov/codecov-action@v1
47+
uses: codecov/codecov-action@v4
4748
with:
4849
token: ${{ secrets.CODECOV_TOKEN }}
4950
file: ./coverage.xml
5051
flags: unittests
51-
# fail_ci_if_error: true
52+
check: # This job does nothing and is only used for the branch protection
53+
if: github.event.pull_request.draft == false
54+
55+
needs:
56+
- test
57+
58+
runs-on: ubuntu-latest
59+
60+
steps:
61+
- name: Decide whether the needed jobs succeeded or failed
62+
uses: re-actors/alls-green@release/v1 # nosemgrep
63+
with:
64+
jobs: ${{ toJSON(needs) }}

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 23.7.0
3+
rev: 24.4.2
44
hooks:
55
- id: black
66
files: "^fastapi_code_generator|^tests"
77
exclude: "^tests/data"
88
- repo: https://github.com/pycqa/isort
9-
rev: 5.12.0
9+
rev: 5.13.2
1010
hooks:
1111
- id: isort
1212
files: "^fastapi_code_generator|^tests"

fastapi_code_generator/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def main(
7777
output_dir,
7878
template_dir,
7979
model_path,
80-
enum_field_as_literal,
80+
enum_field_as_literal, # type: ignore[arg-type]
8181
custom_visitors=custom_visitors,
8282
disable_timestamp=disable_timestamp,
8383
generate_routers=generate_routers,
@@ -131,7 +131,7 @@ def generate_code(
131131
BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
132132
)
133133
if enum_field_as_literal:
134-
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal)
134+
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type]
135135
else:
136136
parser = OpenAPIParser(input_text)
137137
with chdir(output_dir):

fastapi_code_generator/parser.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525
LiteralType,
2626
OpenAPIScope,
2727
PythonVersion,
28-
cached_property,
2928
snooper_to_methods,
3029
)
3130
from datamodel_code_generator.imports import Import, Imports
3231
from datamodel_code_generator.model import DataModel, DataModelFieldBase
3332
from datamodel_code_generator.model import pydantic as pydantic_model
34-
from datamodel_code_generator.model.pydantic import DataModelField
33+
from datamodel_code_generator.model.pydantic import CustomRootType, DataModelField
3534
from datamodel_code_generator.parser.jsonschema import JsonSchemaObject
3635
from datamodel_code_generator.parser.openapi import MediaObject
3736
from datamodel_code_generator.parser.openapi import OpenAPIParser as OpenAPIModelParser
@@ -43,7 +42,8 @@
4342
ResponseObject,
4443
)
4544
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
46-
from pydantic import BaseModel
45+
from datamodel_code_generator.util import cached_property
46+
from pydantic import BaseModel, ValidationInfo
4747

4848
RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r'^application/.*json$')
4949

@@ -72,7 +72,7 @@ def __get_validators__(cls) -> Any:
7272
yield cls.validate
7373

7474
@classmethod
75-
def validate(cls, v: Any) -> Any:
75+
def validate(cls, v: Any, info: ValidationInfo) -> Any:
7676
return cls(v)
7777

7878
@property
@@ -91,8 +91,8 @@ def camelcase(self) -> str:
9191
class Argument(CachedPropertyModel):
9292
name: UsefulStr
9393
type_hint: UsefulStr
94-
default: Optional[UsefulStr]
95-
default_value: Optional[UsefulStr]
94+
default: Optional[UsefulStr] = None
95+
default_value: Optional[UsefulStr] = None
9696
required: bool
9797

9898
def __str__(self) -> str:
@@ -108,20 +108,20 @@ def argument(self) -> str:
108108
class Operation(CachedPropertyModel):
109109
method: UsefulStr
110110
path: UsefulStr
111-
operationId: Optional[UsefulStr]
112-
description: Optional[str]
113-
summary: Optional[str]
111+
operationId: Optional[UsefulStr] = None
112+
description: Optional[str] = None
113+
summary: Optional[str] = None
114114
parameters: List[Dict[str, Any]] = []
115115
responses: Dict[UsefulStr, Any] = {}
116116
deprecated: bool = False
117117
imports: List[Import] = []
118118
security: Optional[List[Dict[str, List[str]]]] = None
119-
tags: Optional[List[str]]
119+
tags: Optional[List[str]] = []
120120
arguments: str = ''
121121
snake_case_arguments: str = ''
122122
request: Optional[Argument] = None
123123
response: str = ''
124-
additional_responses: Dict[str, Dict[str, str]] = {}
124+
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
125125
return_type: str = ''
126126

127127
@cached_property
@@ -245,16 +245,22 @@ def parse_info(self) -> Optional[Dict[str, Any]]:
245245
result['servers'] = servers
246246
return result or None
247247

248-
def parse_parameters(self, parameters: ParameterObject, path: List[str]) -> None:
249-
super().parse_parameters(parameters, path)
250-
self._temporary_operation['_parameters'].append(parameters)
248+
def parse_all_parameters(
249+
self,
250+
name: str,
251+
parameters: List[Union[ReferenceObject, ParameterObject]],
252+
path: List[str],
253+
) -> None:
254+
super().parse_all_parameters(name, parameters, path)
255+
self._temporary_operation['_parameters'].extend(parameters)
251256

252257
def get_parameter_type(
253258
self,
254-
parameters: ParameterObject,
259+
parameters: Union[ReferenceObject, ParameterObject],
255260
snake_case: bool,
256261
path: List[str],
257262
) -> Optional[Argument]:
263+
parameters = self.resolve_object(parameters, ParameterObject)
258264
orig_name = parameters.name
259265
if snake_case:
260266
name = stringcase.snakecase(parameters.name)
@@ -274,7 +280,10 @@ def get_parameter_type(
274280
if not data_type:
275281
if not schema:
276282
schema = parameters.schema_
283+
if schema is None:
284+
raise RuntimeError("schema is None") # pragma: no cover
277285
data_type = self.parse_schema(name, schema, [*path, name])
286+
data_type = self._collapse_root_model(data_type)
278287
if not schema:
279288
return None
280289

@@ -290,16 +299,18 @@ def get_parameter_type(
290299
self.imports_for_fastapi.append(
291300
Import(from_='fastapi', import_=param_is)
292301
)
293-
default: Optional[
294-
str
295-
] = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
302+
default: Optional[str] = (
303+
f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
304+
)
296305
else:
297306
default = repr(schema.default) if schema.has_default else None
298307
self.imports_for_fastapi.append(field.imports)
299308
self.data_types.append(field.data_type)
309+
if field.name is None:
310+
raise RuntimeError("field.name is None") # pragma: no cover
300311
return Argument(
301-
name=field.name,
302-
type_hint=field.type_hint,
312+
name=UsefulStr(field.name),
313+
type_hint=UsefulStr(field.type_hint),
303314
default=default, # type: ignore
304315
default_value=schema.default,
305316
required=field.required,
@@ -361,11 +372,12 @@ def parse_request_body(
361372
data_type = self.parse_schema(
362373
name, media_obj.schema_, [*path, media_type]
363374
)
375+
data_type = self._collapse_root_model(data_type)
364376
arguments.append(
365377
# TODO: support multiple body
366378
Argument(
367379
name='body', # type: ignore
368-
type_hint=data_type.type_hint,
380+
type_hint=UsefulStr(data_type.type_hint),
369381
required=request_body.required,
370382
)
371383
)
@@ -406,17 +418,18 @@ def parse_request_body(
406418
)
407419
self._temporary_operation['_request'] = arguments[0] if arguments else None
408420

409-
def parse_responses(
421+
def parse_responses( # type: ignore[override]
410422
self,
411423
name: str,
412424
responses: Dict[str, Union[ResponseObject, ReferenceObject]],
413425
path: List[str],
414-
) -> Dict[str, Dict[str, DataType]]:
415-
data_types = super().parse_responses(name, responses, path)
426+
) -> Dict[Union[str, int], Dict[str, DataType]]:
427+
data_types = super().parse_responses(name, responses, path) # type: ignore[arg-type]
416428
status_code_200 = data_types.get('200')
417429
if status_code_200:
418430
data_type = list(status_code_200.values())[0]
419431
if data_type:
432+
data_type = self._collapse_root_model(data_type)
420433
self.data_types.append(data_type)
421434
else:
422435
data_type = DataType(type='None')
@@ -466,3 +479,24 @@ def parse_operation(
466479
path=f'/{path_name}', # type: ignore
467480
method=method, # type: ignore
468481
)
482+
483+
def _collapse_root_model(self, data_type: DataType) -> DataType:
484+
reference = data_type.reference
485+
import functools
486+
487+
if not (
488+
reference
489+
and (
490+
len(reference.children) == 1
491+
or functools.reduce(lambda a, b: a == b, reference.children)
492+
)
493+
):
494+
return data_type
495+
source = reference.source
496+
if not isinstance(source, CustomRootType):
497+
return data_type
498+
data_type.remove_reference()
499+
data_type = source.fields[0].data_type
500+
if source in self.results:
501+
self.results.remove(source)
502+
return data_type

0 commit comments

Comments
 (0)