Skip to content

Commit df8996a

Browse files
authored
stubgen: properly sort & add newlines to imports in generated stubs (#462)
1 parent 754e1b2 commit df8996a

10 files changed

+54
-13
lines changed

src/stubgen.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class and repeatedly call ``.put()`` to register modules or contents within the
5959
import textwrap
6060
import importlib
6161
import importlib.machinery
62+
import importlib.util
6263
import types
6364
import typing
6465
from dataclasses import dataclass
@@ -1089,13 +1090,43 @@ def type_str(self, tp: Union[List[Any], Tuple[Any, ...], Dict[Any, Any], Any]) -
10891090
result = repr(tp)
10901091
return self.simplify_types(result)
10911092

1093+
def check_party(self, module: str) -> Literal[0, 1, 2]:
1094+
"""
1095+
Check source of module
1096+
0 = From stdlib
1097+
1 = From 3rd party package
1098+
2 = From the package being built
1099+
"""
1100+
if module.startswith(".") or module == self.module.__name__.split('.')[0]:
1101+
return 2
1102+
1103+
try:
1104+
spec = importlib.util.find_spec(module)
1105+
except ModuleNotFoundError:
1106+
return 1
1107+
1108+
if spec:
1109+
if spec.origin and "site-packages" in spec.origin:
1110+
return 1
1111+
else:
1112+
return 0
1113+
else:
1114+
return 1
1115+
10921116
def get(self) -> str:
10931117
"""Generate the final stub output"""
10941118
s = ""
1119+
last_party = None
10951120

1096-
for module in sorted(self.imports):
1121+
for module in sorted(self.imports, key=lambda i: str(self.check_party(i)) + i):
10971122
imports = self.imports[module]
10981123
items: List[str] = []
1124+
party = self.check_party(module)
1125+
1126+
if party != last_party:
1127+
if last_party is not None:
1128+
s += "\n"
1129+
last_party = party
10991130

11001131
for (k, v1), v2 in imports.items():
11011132
if k is None:
@@ -1108,15 +1139,16 @@ def get(self) -> str:
11081139
items.append(f"{k} as {v2}")
11091140
else:
11101141
items.append(k)
1111-
1142+
1143+
items = sorted(items)
11121144
if items:
11131145
items_v0 = ", ".join(items)
11141146
items_v0 = f"from {module} import {items_v0}\n"
11151147
items_v1 = "(\n " + ",\n ".join(items) + "\n)"
11161148
items_v1 = f"from {module} import {items_v1}\n"
11171149
s += items_v0 if len(items_v0) <= 70 else items_v1
1118-
if s:
1119-
s += "\n"
1150+
1151+
s += "\n\n"
11201152
s += self.put_abstract_enum_class()
11211153

11221154
# Append the main generated stub
@@ -1335,7 +1367,6 @@ def add_pattern(query: str, lines: List[str]):
13351367

13361368
def main(args: Optional[List[str]] = None) -> None:
13371369
import sys
1338-
import os
13391370

13401371
# Ensure that the current directory is on the path
13411372
if "" not in sys.path and "." not in sys.path:

tests/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
is_pypy = platform.python_implementation() == 'PyPy'
66
is_darwin = platform.system() == 'Darwin'
77

8-
def collect():
8+
def collect() -> None:
99
if is_pypy:
10-
for i in range(3):
10+
for _ in range(3):
1111
gc.collect()
1212
else:
1313
gc.collect()

tests/py_stub_test.pyi.ref

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Callable
2-
from typing import overload, TypeVar
2+
from typing import TypeVar, overload
3+
34

45
class AClass:
56
__annotations__: dict = {'STATIC_VAR' : int}

tests/test_classes_ext.pyi.ref

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import overload
22

3+
34
class A:
45
def __init__(self, arg: int, /) -> None: ...
56

tests/test_enum_ext.pyi.ref

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import overload
22

3+
34
class _Enum:
45
def __init__(self, arg: object, /) -> None: ...
56
def __repr__(self, /) -> str: ...

tests/test_functions_ext.pyi.ref

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable
22
import types
3-
from typing import overload, Annotated, Any
3+
from typing import Annotated, Any, overload
4+
45

56
def call_guard_value() -> int: ...
67

tests/test_make_iterator_ext.pyi.ref

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Iterator, Mapping
22
from typing import overload
33

4+
45
class IdentityMap:
56
def __init__(self) -> None: ...
67

tests/test_ndarray_ext.pyi.ref

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from numpy.typing import ArrayLike
21
from typing import Annotated, overload
32

3+
from numpy.typing import ArrayLike
4+
5+
46
class Cls:
57
def __init__(self) -> None: ...
68

tests/test_stl_ext.pyi.ref

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from collections.abc import Sequence, Callable, Mapping, Set
1+
from collections.abc import Callable, Mapping, Sequence, Set
22
import os
33
import pathlib
44
from typing import overload
55

6+
67
class ClassWithMovableField:
78
def __init__(self) -> None: ...
89

tests/test_typing_ext.pyi.ref

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from collections.abc import Iterable
2+
from typing import Generic, Optional, Self, TypeAlias, TypeVar
3+
14
from . import submodule as submodule
25
from .submodule import F as F, f as f2
3-
from collections.abc import Iterable
4-
from typing import Self, Optional, TypeAlias, TypeVar, Generic
6+
57

68
# a prefix
79

0 commit comments

Comments
 (0)