25
25
LiteralType ,
26
26
OpenAPIScope ,
27
27
PythonVersion ,
28
- cached_property ,
29
28
snooper_to_methods ,
30
29
)
31
30
from datamodel_code_generator .imports import Import , Imports
32
31
from datamodel_code_generator .model import DataModel , DataModelFieldBase
33
32
from datamodel_code_generator .model import pydantic as pydantic_model
34
- from datamodel_code_generator .model .pydantic import DataModelField
33
+ from datamodel_code_generator .model .pydantic import CustomRootType , DataModelField
35
34
from datamodel_code_generator .parser .jsonschema import JsonSchemaObject
36
35
from datamodel_code_generator .parser .openapi import MediaObject
37
36
from datamodel_code_generator .parser .openapi import OpenAPIParser as OpenAPIModelParser
43
42
ResponseObject ,
44
43
)
45
44
from datamodel_code_generator .types import DataType , DataTypeManager , StrictTypes
46
- from pydantic import BaseModel
45
+ from datamodel_code_generator .util import cached_property
46
+ from pydantic import BaseModel , ValidationInfo
47
47
48
48
RE_APPLICATION_JSON_PATTERN : Pattern [str ] = re .compile (r'^application/.*json$' )
49
49
@@ -72,7 +72,7 @@ def __get_validators__(cls) -> Any:
72
72
yield cls .validate
73
73
74
74
@classmethod
75
- def validate (cls , v : Any ) -> Any :
75
+ def validate (cls , v : Any , info : ValidationInfo ) -> Any :
76
76
return cls (v )
77
77
78
78
@property
@@ -91,8 +91,8 @@ def camelcase(self) -> str:
91
91
class Argument (CachedPropertyModel ):
92
92
name : UsefulStr
93
93
type_hint : UsefulStr
94
- default : Optional [UsefulStr ]
95
- default_value : Optional [UsefulStr ]
94
+ default : Optional [UsefulStr ] = None
95
+ default_value : Optional [UsefulStr ] = None
96
96
required : bool
97
97
98
98
def __str__ (self ) -> str :
@@ -108,20 +108,20 @@ def argument(self) -> str:
108
108
class Operation (CachedPropertyModel ):
109
109
method : UsefulStr
110
110
path : UsefulStr
111
- operationId : Optional [UsefulStr ]
112
- description : Optional [str ]
113
- summary : Optional [str ]
111
+ operationId : Optional [UsefulStr ] = None
112
+ description : Optional [str ] = None
113
+ summary : Optional [str ] = None
114
114
parameters : List [Dict [str , Any ]] = []
115
115
responses : Dict [UsefulStr , Any ] = {}
116
116
deprecated : bool = False
117
117
imports : List [Import ] = []
118
118
security : Optional [List [Dict [str , List [str ]]]] = None
119
- tags : Optional [List [str ]]
119
+ tags : Optional [List [str ]] = []
120
120
arguments : str = ''
121
121
snake_case_arguments : str = ''
122
122
request : Optional [Argument ] = None
123
123
response : str = ''
124
- additional_responses : Dict [str , Dict [str , str ]] = {}
124
+ additional_responses : Dict [Union [ str , int ] , Dict [str , str ]] = {}
125
125
return_type : str = ''
126
126
127
127
@cached_property
@@ -245,16 +245,22 @@ def parse_info(self) -> Optional[Dict[str, Any]]:
245
245
result ['servers' ] = servers
246
246
return result or None
247
247
248
- def parse_parameters (self , parameters : ParameterObject , path : List [str ]) -> None :
249
- super ().parse_parameters (parameters , path )
250
- self ._temporary_operation ['_parameters' ].append (parameters )
248
+ def parse_all_parameters (
249
+ self ,
250
+ name : str ,
251
+ parameters : List [Union [ReferenceObject , ParameterObject ]],
252
+ path : List [str ],
253
+ ) -> None :
254
+ super ().parse_all_parameters (name , parameters , path )
255
+ self ._temporary_operation ['_parameters' ].extend (parameters )
251
256
252
257
def get_parameter_type (
253
258
self ,
254
- parameters : ParameterObject ,
259
+ parameters : Union [ ReferenceObject , ParameterObject ] ,
255
260
snake_case : bool ,
256
261
path : List [str ],
257
262
) -> Optional [Argument ]:
263
+ parameters = self .resolve_object (parameters , ParameterObject )
258
264
orig_name = parameters .name
259
265
if snake_case :
260
266
name = stringcase .snakecase (parameters .name )
@@ -274,7 +280,10 @@ def get_parameter_type(
274
280
if not data_type :
275
281
if not schema :
276
282
schema = parameters .schema_
283
+ if schema is None :
284
+ raise RuntimeError ("schema is None" ) # pragma: no cover
277
285
data_type = self .parse_schema (name , schema , [* path , name ])
286
+ data_type = self ._collapse_root_model (data_type )
278
287
if not schema :
279
288
return None
280
289
@@ -290,16 +299,18 @@ def get_parameter_type(
290
299
self .imports_for_fastapi .append (
291
300
Import (from_ = 'fastapi' , import_ = param_is )
292
301
)
293
- default : Optional [
294
- str
295
- ] = f" { param_is } ( { '...' if field . required else repr ( schema . default ) } , alias=' { orig_name } ')"
302
+ default : Optional [str ] = (
303
+ f" { param_is } ( { '...' if field . required else repr ( schema . default ) } , alias=' { orig_name } ')"
304
+ )
296
305
else :
297
306
default = repr (schema .default ) if schema .has_default else None
298
307
self .imports_for_fastapi .append (field .imports )
299
308
self .data_types .append (field .data_type )
309
+ if field .name is None :
310
+ raise RuntimeError ("field.name is None" ) # pragma: no cover
300
311
return Argument (
301
- name = field .name ,
302
- type_hint = field .type_hint ,
312
+ name = UsefulStr ( field .name ) ,
313
+ type_hint = UsefulStr ( field .type_hint ) ,
303
314
default = default , # type: ignore
304
315
default_value = schema .default ,
305
316
required = field .required ,
@@ -361,11 +372,12 @@ def parse_request_body(
361
372
data_type = self .parse_schema (
362
373
name , media_obj .schema_ , [* path , media_type ]
363
374
)
375
+ data_type = self ._collapse_root_model (data_type )
364
376
arguments .append (
365
377
# TODO: support multiple body
366
378
Argument (
367
379
name = 'body' , # type: ignore
368
- type_hint = data_type .type_hint ,
380
+ type_hint = UsefulStr ( data_type .type_hint ) ,
369
381
required = request_body .required ,
370
382
)
371
383
)
@@ -406,17 +418,18 @@ def parse_request_body(
406
418
)
407
419
self ._temporary_operation ['_request' ] = arguments [0 ] if arguments else None
408
420
409
- def parse_responses (
421
+ def parse_responses ( # type: ignore[override]
410
422
self ,
411
423
name : str ,
412
424
responses : Dict [str , Union [ResponseObject , ReferenceObject ]],
413
425
path : List [str ],
414
- ) -> Dict [str , Dict [str , DataType ]]:
415
- data_types = super ().parse_responses (name , responses , path )
426
+ ) -> Dict [Union [ str , int ] , Dict [str , DataType ]]:
427
+ data_types = super ().parse_responses (name , responses , path ) # type: ignore[arg-type]
416
428
status_code_200 = data_types .get ('200' )
417
429
if status_code_200 :
418
430
data_type = list (status_code_200 .values ())[0 ]
419
431
if data_type :
432
+ data_type = self ._collapse_root_model (data_type )
420
433
self .data_types .append (data_type )
421
434
else :
422
435
data_type = DataType (type = 'None' )
@@ -466,3 +479,24 @@ def parse_operation(
466
479
path = f'/{ path_name } ' , # type: ignore
467
480
method = method , # type: ignore
468
481
)
482
+
483
+ def _collapse_root_model (self , data_type : DataType ) -> DataType :
484
+ reference = data_type .reference
485
+ import functools
486
+
487
+ if not (
488
+ reference
489
+ and (
490
+ len (reference .children ) == 1
491
+ or functools .reduce (lambda a , b : a == b , reference .children )
492
+ )
493
+ ):
494
+ return data_type
495
+ source = reference .source
496
+ if not isinstance (source , CustomRootType ):
497
+ return data_type
498
+ data_type .remove_reference ()
499
+ data_type = source .fields [0 ].data_type
500
+ if source in self .results :
501
+ self .results .remove (source )
502
+ return data_type
0 commit comments