Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit a4857d4

Browse files
FuegoFroshoyer
authored andcommitted
Ensure stubs are valid for Python 2 and fix running of tests (#19)
* Ensure stubs are valid for Python 2 and fix running of tests The stubs contained an unconditional reference to SupportsBytes, which only exists in Python 3. To make these valid on Python 2, conditionally import that Protocol in Python 3 and otherwise use a dummy class in Python 2. Also have `ndarray` extend `Contains`, while we're here. This also extends the test suites to run all tests against both Python 2 and Python 3, with the ability to specify that certain tests should only be run against Python 3 (eg to test Python 3 exclusive operators). This should help prevent errors like this moving forward. One downside of this is that flake8 doesn't understand the `# type:` comments, so it thinks that imports from `typing` are unused. A workaround for this is to add `# noqa: F401` at the end of the relevant imports, though this is a bit tedious. Finally, change how test requirements are installed and how the `numpy-stubs` package is exposed to mypy, and update the README/Travis file to reflect this. See python/mypy#5007 for more details about the rational behind this change. * Split `pip install .` out of the `test-requirements.txt` file, update Travis and README files accordingly.
1 parent c9d28a2 commit a4857d4

File tree

5 files changed

+95
-36
lines changed

5 files changed

+95
-36
lines changed

numpy-stubs/__init__.pyi

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,32 @@
11
import builtins
2+
import sys
23

4+
from numpy.core._internal import _ctypes
35
from typing import (
4-
Any, Dict, Iterable, List, Optional, Mapping, Sequence, Sized,
5-
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, SupportsAbs,
6-
Text, Tuple, Type, TypeVar, Union,
6+
Any,
7+
Container,
8+
Dict,
9+
Iterable,
10+
List,
11+
Mapping,
12+
Optional,
13+
Sequence,
14+
Sized,
15+
SupportsAbs,
16+
SupportsComplex,
17+
SupportsFloat,
18+
SupportsInt,
19+
Text,
20+
Tuple,
21+
Type,
22+
TypeVar,
23+
Union,
724
)
825

9-
import sys
10-
11-
from numpy.core._internal import _ctypes
26+
if sys.version_info[0] < 3:
27+
class SupportsBytes: ...
28+
else:
29+
from typing import SupportsBytes
1230

1331
_Shape = Tuple[int, ...]
1432

@@ -325,7 +343,7 @@ class _ArrayOrScalarCommon(SupportsInt, SupportsFloat, SupportsComplex,
325343
def __getattr__(self, name) -> Any: ...
326344

327345

328-
class ndarray(_ArrayOrScalarCommon, Iterable, Sized):
346+
class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
329347
real: ndarray
330348
imag: ndarray
331349

tests/README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,22 @@ reveal_type(x) # E: <type name>
2222
Right now, the error messages and types are must be **contained within
2323
corresponding mypy message**.
2424

25+
Test files that end in `_py3.py` will only be type checked against Python 3.
26+
All other test files must be valid in both Python 2 and Python 3.
27+
2528
## Running the tests
2629

30+
To setup your test environment, cd into the root of the repo and run:
31+
32+
33+
```
34+
pip install -r test-requirements.txt
35+
pip install .
36+
```
37+
38+
Note that due to how mypy reads type information in PEP 561 packages, you'll
39+
need to re-run the `pip install .` command each time you change the stubs.
40+
2741
We use `py.test` to orchestrate our tests. You can just run:
2842

2943
```
@@ -34,6 +48,16 @@ to run the entire test suite. To run `mypy` on a specific file (which
3448
can be useful for debugging), you can also run:
3549

3650
```
37-
$ cd tests
38-
$ MYPYPATH=.. mypy <file_path>
51+
mypy <file_path>
52+
```
53+
54+
Note that it is assumed that all of these commands target the same
55+
underlying Python interpreter. To ensure you're using the intended version of
56+
Python you can use `python -m` versions of these commands instead:
57+
58+
```
59+
python -m pip install -r test-requirements.txt
60+
python -m pip install .
61+
python -m pytest
62+
python -m mypy <file_path>
3963
```

tests/pass/simple.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import operator
33

44
import numpy as np
5-
from typing import Iterable
5+
from typing import Iterable # noqa: F401
66

77
# Basic checks
88
array = np.array([1, 2])
9-
def ndarray_func(x: np.ndarray) -> np.ndarray:
9+
def ndarray_func(x):
10+
# type: (np.ndarray) -> np.ndarray
1011
return x
1112
ndarray_func(np.array([1, 2]))
1213
array == 1
@@ -28,7 +29,8 @@ def ndarray_func(x: np.ndarray) -> np.ndarray:
2829
np.dtype((np.int32, (np.int8, 4)))
2930

3031
# Iteration and indexing
31-
def iterable_func(x: Iterable) -> Iterable:
32+
def iterable_func(x):
33+
# type: (Iterable) -> Iterable
3234
return x
3335
iterable_func(array)
3436
[element for element in array]
@@ -122,8 +124,6 @@ def iterable_func(x: Iterable) -> Iterable:
122124
1 | array
123125
array |= 1
124126

125-
array @ array
126-
127127
# unary arithmetic
128128
-array
129129
+array

tests/pass/simple_py3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import numpy as np
2+
3+
array = np.array([1, 2])
4+
5+
# The @ operator is not in python 2
6+
array @ array

tests/test_stubs.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,48 @@
33
import pytest
44
from mypy import api
55

6-
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
7-
PASS_DIR = os.path.join(os.path.dirname(__file__), "pass")
8-
FAIL_DIR = os.path.join(os.path.dirname(__file__), "fail")
9-
REVEAL_DIR = os.path.join(os.path.dirname(__file__), "reveal")
10-
11-
os.environ['MYPYPATH'] = ROOT_DIR
6+
TESTS_DIR = os.path.dirname(__file__)
7+
PASS_DIR = os.path.join(TESTS_DIR, "pass")
8+
FAIL_DIR = os.path.join(TESTS_DIR, "fail")
9+
REVEAL_DIR = os.path.join(TESTS_DIR, "reveal")
1210

1311

1412
def get_test_cases(directory):
1513
for root, __, files in os.walk(directory):
1614
for fname in files:
1715
if os.path.splitext(fname)[-1] == ".py":
18-
# yield relative path for nice py.test name
19-
yield os.path.relpath(
20-
os.path.join(root, fname), start=directory)
21-
22-
23-
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
24-
def test_success(path):
25-
stdout, stderr, exitcode = api.run([os.path.join(PASS_DIR, path)])
16+
fullpath = os.path.join(root, fname)
17+
# Use relative path for nice py.test name
18+
relpath = os.path.relpath(fullpath, start=directory)
19+
skip_py2 = fname.endswith("_py3.py")
20+
21+
for py_version_number in (2, 3):
22+
if py_version_number == 2 and skip_py2:
23+
continue
24+
py2_arg = ['--py2'] if py_version_number == 2 else []
25+
26+
yield pytest.param(
27+
fullpath,
28+
py2_arg,
29+
# Manually specify a name for the test
30+
id="{} - python{}".format(relpath, py_version_number),
31+
)
32+
33+
34+
@pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR))
35+
def test_success(path, py2_arg):
36+
stdout, stderr, exitcode = api.run([path] + py2_arg)
2637
assert stdout == ''
2738
assert exitcode == 0
2839

2940

30-
@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
31-
def test_fail(path):
32-
stdout, stderr, exitcode = api.run([os.path.join(FAIL_DIR, path)])
41+
@pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR))
42+
def test_fail(path, py2_arg):
43+
stdout, stderr, exitcode = api.run([path] + py2_arg)
3344

3445
assert exitcode != 0
3546

36-
with open(os.path.join(FAIL_DIR, path)) as fin:
47+
with open(path) as fin:
3748
lines = fin.readlines()
3849

3950
errors = {}
@@ -59,11 +70,11 @@ def test_fail(path):
5970
pytest.fail(f'Error {repr(errors[lineno])} not found')
6071

6172

62-
@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
63-
def test_reveal(path):
64-
stdout, stderr, exitcode = api.run([os.path.join(REVEAL_DIR, path)])
73+
@pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR))
74+
def test_reveal(path, py2_arg):
75+
stdout, stderr, exitcode = api.run([path] + py2_arg)
6576

66-
with open(os.path.join(REVEAL_DIR, path)) as fin:
77+
with open(path) as fin:
6778
lines = fin.readlines()
6879

6980
for error_line in stdout.split("\n"):

0 commit comments

Comments
 (0)