Skip to content

Commit f4903fa

Browse files
committed
Some more cases
1 parent 13476f9 commit f4903fa

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

mypy/build.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -860,35 +860,41 @@ def find_module(self, id: str) -> Optional[str]:
860860
return None
861861

862862
def find_modules_recursive(self, module: str) -> List[BuildSource]:
863+
hits = set() # type: Set[str]
864+
result = [] # type: List[BuildSource]
865+
for src in self._find_modules_recursive(module):
866+
if src.module not in hits:
867+
hits.add(src.module)
868+
result.append(src)
869+
return result
870+
871+
def _find_modules_recursive(self, module: str) -> List[BuildSource]:
863872
module_paths = self._find_module(module)
864873

865-
result = []
866-
hits = set() # type: Set[str]
874+
srcs = [] # type: List[BuildSource]
867875
for path in module_paths:
868876
if is_module_path(path) or is_pkg_path(path):
869-
if module not in hits:
870-
result.append(BuildSource(path, module, None))
871-
hits.add(module)
877+
srcs.append(BuildSource(path, module, None))
872878

873879
if is_pkg_path(path):
874880
path = dirname(path)
875-
result += self._traverse_package(module, path)
881+
for submodule in self._find_submodules(module, path):
882+
srcs += self._find_modules_recursive(submodule)
876883
elif is_namespace_path(path):
877-
result += self._traverse_package(module, path)
884+
for submodule in self._find_submodules(module, path):
885+
srcs += self._find_modules_recursive(submodule)
878886

879-
return result
887+
return srcs
880888

881-
def _traverse_package(self, module, path) -> List[BuildSource]:
882-
result = [] # type: List[BuildSource]
889+
def _find_submodules(self, module, path) -> Iterator[str]:
883890
for item in list_dir(path):
884891
if item == '__init__.py' or item == '__init__.pyi':
885892
continue
886893

887894
if item.endswith(tuple(PYTHON_EXTENSIONS)):
888895
item = item.split('.')[0]
889896

890-
result += self.find_modules_recursive(module + '.' + item)
891-
return result
897+
yield module + '.' + item
892898

893899
def _collect_paths(self, paths: List[str], last_comp: str) -> List[str]:
894900
"""

mypy/test/testmodulediscovery.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,14 @@ def test_find_modules_recursive_with_namespace(self) -> None:
125125
m = ModuleDiscovery(['dir1', 'dir2'], namespaces_allowed=True)
126126
srcs = m.find_modules_recursive('mod')
127127
assert_equal([s.module for s in srcs], ['mod.a', 'mod.b'])
128+
129+
def test_find_modules_recursive_with_stubs(self) -> None:
130+
self.files = {
131+
os.path.join('dir1', 'mod', '__init__.py'),
132+
os.path.join('dir1', 'mod', 'a.py'),
133+
os.path.join('dir2', 'mod', '__init__.pyi'),
134+
os.path.join('dir2', 'mod', 'a.pyi'),
135+
}
136+
m = ModuleDiscovery(['dir1', 'dir2'], namespaces_allowed=True)
137+
srcs = m.find_modules_recursive('mod')
138+
assert_equal([s.module for s in srcs], ['mod', 'mod.a'])

0 commit comments

Comments
 (0)