@@ -20,19 +20,23 @@ def get_param_type_string(self) -> str:
2020 """Get type string used by client codegen"""
2121 return self .get_type_string ()
2222
23- def get_imports (self ) -> Set [str ]:
24- """Get schema needed imports for creating the schema """
23+ def get_model_imports (self ) -> Set [str ]:
24+ """Get schema needed imports for model codegen """
2525 return set ()
2626
27+ def get_type_imports (self ) -> Set [str ]:
28+ """Get schema needed imports for types codegen"""
29+ return self .get_model_imports ()
30+
2731 def get_param_imports (self ) -> Set [str ]:
28- """Get schema needed imports for typing params """
29- return self .get_imports ()
32+ """Get schema needed imports for client param codegen """
33+ return self .get_model_imports ()
3034
3135 def get_using_imports (self ) -> Set [str ]:
32- """Get schema needed imports for using """
33- return self .get_imports ()
36+ """Get schema needed imports for client request codegen """
37+ return self .get_model_imports ()
3438
35- def get_default_args (self ) -> Dict [str , str ]:
39+ def _get_default_args (self ) -> Dict [str , str ]:
3640 """Get pydantic field info args"""
3741 default = self .default
3842 args = {}
@@ -59,18 +63,26 @@ def get_type_string(self) -> str:
5963 return type_string if self .required else f"Union[Unset, { type_string } ]"
6064
6165 def get_param_type_string (self ) -> str :
66+ """Get type string used by client codegen"""
6267 type_string = self .schema_data .get_param_type_string ()
6368 return type_string if self .required else f"Union[Unset, { type_string } ]"
6469
65- def get_imports (self ) -> Set [str ]:
66- """Get schema needed imports for creating the schema """
67- imports = self .schema_data .get_imports ()
70+ def get_model_imports (self ) -> Set [str ]:
71+ """Get schema needed imports for model codegen """
72+ imports = self .schema_data .get_model_imports ()
6873 imports .add ("from pydantic import Field" )
6974 if not self .required :
7075 imports .add ("from typing import Union" )
7176 imports .add ("from githubkit.utils import UNSET, Unset" )
7277 return imports
7378
79+ def get_type_imports (self ) -> Set [str ]:
80+ """Get schema needed imports for type codegen"""
81+ imports = self .schema_data .get_type_imports ()
82+ if not self .required :
83+ imports .add ("from typing_extensions import NotRequired" )
84+ return imports
85+
7486 def get_param_imports (self ) -> Set [str ]:
7587 """Get schema needed imports for typing params"""
7688 imports = self .schema_data .get_param_imports ()
@@ -79,8 +91,8 @@ def get_param_imports(self) -> Set[str]:
7991 imports .add ("from githubkit.utils import UNSET, Unset" )
8092 return imports
8193
82- def get_default_string (self ) -> str :
83- args = self .schema_data .get_default_args ()
94+ def _get_default_string (self ) -> str :
95+ args = self .schema_data ._get_default_args ()
8496 if "default" not in args and "default_factory" not in args :
8597 args ["default" ] = "..." if self .required else "UNSET"
8698 if self .prop_name != self .name :
@@ -103,12 +115,12 @@ def get_param_defination(self) -> str:
103115 def get_model_defination (self ) -> str :
104116 """Get defination used by model codegen"""
105117 type_ = self .get_type_string ()
106- default = self .get_default_string ()
118+ default = self ._get_default_string ()
107119 return f"{ self .prop_name } : { type_ } = { default } "
108120
109121 def get_type_defination (self ) -> str :
110122 """Get defination used by types codegen"""
111- type_ = self .get_param_type_string ()
123+ type_ = self .schema_data . get_param_type_string ()
112124 return (
113125 f"{ self .prop_name } : { type_ if self .required else f'NotRequired[{ type_ } ]' } "
114126 )
@@ -117,8 +129,8 @@ def get_type_defination(self) -> str:
117129class AnySchema (SchemaData ):
118130 _type_string : ClassVar [str ] = "Any"
119131
120- def get_imports (self ) -> Set [str ]:
121- imports = super ().get_imports ()
132+ def get_model_imports (self ) -> Set [str ]:
133+ imports = super ().get_model_imports ()
122134 imports .add ("from typing import Any" )
123135 return imports
124136
@@ -140,8 +152,8 @@ class IntSchema(SchemaData):
140152
141153 _type_string : ClassVar [str ] = "int"
142154
143- def get_default_args (self ) -> Dict [str , str ]:
144- args = super ().get_default_args ()
155+ def _get_default_args (self ) -> Dict [str , str ]:
156+ args = super ()._get_default_args ()
145157 if self .multiple_of is not None :
146158 args ["multiple_of" ] = repr (self .multiple_of )
147159 if self .maximum is not None :
@@ -164,8 +176,8 @@ class FloatSchema(SchemaData):
164176
165177 _type_string : ClassVar [str ] = "float"
166178
167- def get_default_args (self ) -> Dict [str , str ]:
168- args = super ().get_default_args ()
179+ def _get_default_args (self ) -> Dict [str , str ]:
180+ args = super ()._get_default_args ()
169181 if self .multiple_of is not None :
170182 args ["multiple_of" ] = str (self .multiple_of )
171183 if self .maximum is not None :
@@ -186,8 +198,8 @@ class StringSchema(SchemaData):
186198
187199 _type_string : ClassVar [str ] = "str"
188200
189- def get_default_args (self ) -> Dict [str , str ]:
190- args = super ().get_default_args ()
201+ def _get_default_args (self ) -> Dict [str , str ]:
202+ args = super ()._get_default_args ()
191203 if self .min_length is not None :
192204 args ["min_length" ] = str (self .min_length )
193205 if self .max_length is not None :
@@ -200,26 +212,26 @@ def get_default_args(self) -> Dict[str, str]:
200212class DateTimeSchema (SchemaData ):
201213 _type_string : ClassVar [str ] = "datetime"
202214
203- def get_imports (self ) -> Set [str ]:
204- imports = super ().get_imports ()
215+ def get_model_imports (self ) -> Set [str ]:
216+ imports = super ().get_model_imports ()
205217 imports .add ("from datetime import datetime" )
206218 return imports
207219
208220
209221class DateSchema (SchemaData ):
210222 _type_string : ClassVar [str ] = "date"
211223
212- def get_imports (self ) -> Set [str ]:
213- imports = super ().get_imports ()
224+ def get_model_imports (self ) -> Set [str ]:
225+ imports = super ().get_model_imports ()
214226 imports .add ("from datetime import date" )
215227 return imports
216228
217229
218230class FileSchema (SchemaData ):
219231 _type_string : ClassVar [str ] = "FileTypes"
220232
221- def get_imports (self ) -> Set [str ]:
222- imports = super ().get_imports ()
233+ def get_model_imports (self ) -> Set [str ]:
234+ imports = super ().get_model_imports ()
223235 imports .add ("from githubkit.typing import FileTypes" )
224236 return imports
225237
@@ -236,10 +248,15 @@ def get_type_string(self) -> str:
236248 def get_param_type_string (self ) -> str :
237249 return f"List[{ self .item_schema .get_param_type_string ()} ]"
238250
239- def get_imports (self ) -> Set [str ]:
240- imports = super ().get_imports ()
251+ def get_model_imports (self ) -> Set [str ]:
252+ imports = super ().get_model_imports ()
241253 imports .add ("from typing import List" )
242- imports .update (self .item_schema .get_imports ())
254+ imports .update (self .item_schema .get_model_imports ())
255+ return imports
256+
257+ def get_type_imports (self ) -> Set [str ]:
258+ imports = {"from typing import List" }
259+ imports .update (self .item_schema .get_type_imports ())
243260 return imports
244261
245262 def get_param_imports (self ) -> Set [str ]:
@@ -252,8 +269,8 @@ def get_using_imports(self) -> Set[str]:
252269 imports .update (self .item_schema .get_using_imports ())
253270 return imports
254271
255- def get_default_args (self ) -> Dict [str , str ]:
256- args = super ().get_default_args ()
272+ def _get_default_args (self ) -> Dict [str , str ]:
273+ args = super ()._get_default_args ()
257274 # FIXME: remove list constraints due to forwardref not supported
258275 # See https://github.com/samuelcolvin/pydantic/issues/3745
259276 if isinstance (self .item_schema , (ModelSchema , UnionSchema )):
@@ -282,8 +299,8 @@ def is_float_enum(self) -> bool:
282299 def get_type_string (self ) -> str :
283300 return f"Literal[{ ', ' .join (repr (value ) for value in self .values )} ]"
284301
285- def get_imports (self ) -> Set [str ]:
286- imports = super ().get_imports ()
302+ def get_model_imports (self ) -> Set [str ]:
303+ imports = super ().get_model_imports ()
287304 imports .add ("from typing import Literal" )
288305 return imports
289306
@@ -299,13 +316,19 @@ def get_type_string(self) -> str:
299316 def get_param_type_string (self ) -> str :
300317 return f"{ self .class_name } Type"
301318
302- def get_imports (self ) -> Set [str ]:
303- imports = super ().get_imports ()
319+ def get_model_imports (self ) -> Set [str ]:
320+ imports = super ().get_model_imports ()
304321 imports .add ("from pydantic import BaseModel" )
305322 if self .allow_extra :
306323 imports .add ("from pydantic import Extra" )
307324 for prop in self .properties :
308- imports .update (prop .get_imports ())
325+ imports .update (prop .get_model_imports ())
326+ return imports
327+
328+ def get_type_imports (self ) -> Set [str ]:
329+ imports = {"from typing_extensions import TypedDict" }
330+ for prop in self .properties :
331+ imports .update (prop .get_type_imports ())
309332 return imports
310333
311334 def get_param_imports (self ) -> Set [str ]:
@@ -331,11 +354,17 @@ def get_param_type_string(self) -> str:
331354 return self .schemas [0 ].get_param_type_string ()
332355 return f"Union[{ ', ' .join (schema .get_param_type_string () for schema in self .schemas )} ]"
333356
334- def get_imports (self ) -> Set [str ]:
335- imports = super ().get_imports ()
357+ def get_model_imports (self ) -> Set [str ]:
358+ imports = super ().get_model_imports ()
336359 imports .add ("from typing import Union" )
337360 for schema in self .schemas :
338- imports .update (schema .get_imports ())
361+ imports .update (schema .get_model_imports ())
362+ return imports
363+
364+ def get_type_imports (self ) -> Set [str ]:
365+ imports = {"from typing import Union" }
366+ for schema in self .schemas :
367+ imports .update (schema .get_type_imports ())
339368 return imports
340369
341370 def get_param_imports (self ) -> Set [str ]:
@@ -350,11 +379,11 @@ def get_using_imports(self) -> Set[str]:
350379 imports .update (schema .get_using_imports ())
351380 return imports
352381
353- def get_default_args (self ) -> Dict [str , str ]:
382+ def _get_default_args (self ) -> Dict [str , str ]:
354383 args = {}
355384 for schema in self .schemas :
356- args .update (schema .get_default_args ())
357- args .update (super ().get_default_args ())
385+ args .update (schema ._get_default_args ())
386+ args .update (super ()._get_default_args ())
358387 if self .discriminator :
359388 args ["discriminator" ] = self .discriminator
360389 return args
0 commit comments