@@ -59,6 +59,7 @@ class and repeatedly call ``.put()`` to register modules or contents within the
59
59
import textwrap
60
60
import importlib
61
61
import importlib .machinery
62
+ import importlib .util
62
63
import types
63
64
import typing
64
65
from dataclasses import dataclass
@@ -1089,13 +1090,43 @@ def type_str(self, tp: Union[List[Any], Tuple[Any, ...], Dict[Any, Any], Any]) -
1089
1090
result = repr (tp )
1090
1091
return self .simplify_types (result )
1091
1092
1093
+ def check_party (self , module : str ) -> Literal [0 , 1 , 2 ]:
1094
+ """
1095
+ Check source of module
1096
+ 0 = From stdlib
1097
+ 1 = From 3rd party package
1098
+ 2 = From the package being built
1099
+ """
1100
+ if module .startswith ("." ) or module == self .module .__name__ .split ('.' )[0 ]:
1101
+ return 2
1102
+
1103
+ try :
1104
+ spec = importlib .util .find_spec (module )
1105
+ except ModuleNotFoundError :
1106
+ return 1
1107
+
1108
+ if spec :
1109
+ if spec .origin and "site-packages" in spec .origin :
1110
+ return 1
1111
+ else :
1112
+ return 0
1113
+ else :
1114
+ return 1
1115
+
1092
1116
def get (self ) -> str :
1093
1117
"""Generate the final stub output"""
1094
1118
s = ""
1119
+ last_party = None
1095
1120
1096
- for module in sorted (self .imports ):
1121
+ for module in sorted (self .imports , key = lambda i : str ( self . check_party ( i )) + i ):
1097
1122
imports = self .imports [module ]
1098
1123
items : List [str ] = []
1124
+ party = self .check_party (module )
1125
+
1126
+ if party != last_party :
1127
+ if last_party is not None :
1128
+ s += "\n "
1129
+ last_party = party
1099
1130
1100
1131
for (k , v1 ), v2 in imports .items ():
1101
1132
if k is None :
@@ -1108,15 +1139,16 @@ def get(self) -> str:
1108
1139
items .append (f"{ k } as { v2 } " )
1109
1140
else :
1110
1141
items .append (k )
1111
-
1142
+
1143
+ items = sorted (items )
1112
1144
if items :
1113
1145
items_v0 = ", " .join (items )
1114
1146
items_v0 = f"from { module } import { items_v0 } \n "
1115
1147
items_v1 = "(\n " + ",\n " .join (items ) + "\n )"
1116
1148
items_v1 = f"from { module } import { items_v1 } \n "
1117
1149
s += items_v0 if len (items_v0 ) <= 70 else items_v1
1118
- if s :
1119
- s += "\n "
1150
+
1151
+ s += "\n \n "
1120
1152
s += self .put_abstract_enum_class ()
1121
1153
1122
1154
# Append the main generated stub
@@ -1335,7 +1367,6 @@ def add_pattern(query: str, lines: List[str]):
1335
1367
1336
1368
def main (args : Optional [List [str ]] = None ) -> None :
1337
1369
import sys
1338
- import os
1339
1370
1340
1371
# Ensure that the current directory is on the path
1341
1372
if "" not in sys .path and "." not in sys .path :
0 commit comments