Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit f0cb9ff

Browse files
Alan Dushoyer
authored andcommitted
Add NumPy scalar hierarchy (#14)
* Add NumPy scalar hierarchy * Include timedelta64 * Factor out _ArrayLike base class * Cluster away platform specific types * fixup! Factor out _ArrayLike base class * Add _is_real type to fix real and imag typing * fixup! fixup! Factor out _ArrayLike base class * s/_is_real/_real_generic/g * s/_ArrayLike/_ArrayOrScalarCommon array_like means "can be converted to array" * Update dtype.type * Add testing framework * Fix type annotations * Update Travis * Use pytest.fail
1 parent 3074ca3 commit f0cb9ff

File tree

11 files changed

+311
-53
lines changed

11 files changed

+311
-53
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
.mypy_cache
2+
.pytest_cache
3+
__pycache__

.travis.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ notifications:
66

77
install:
88
- pip install -r test-requirements.txt
9-
- if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then pip install flake8-pyi==17.3.0; fi
109

1110
script:
1211
- flake8
13-
- MYPYPATH="." mypy tests/test_simple.py
12+
- py.test
1413

1514
cache:
1615
directories:

numpy/__init__.pyi

Lines changed: 119 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import builtins
33
from typing import (
44
Any, Dict, Iterable, List, Optional, Mapping, Sequence, Sized,
55
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, SupportsAbs,
6-
Text, Tuple, Union,
6+
Text, Tuple, Type, TypeVar, Union,
77
)
88

99
import sys
@@ -131,7 +131,7 @@ class dtype:
131131
def str(self) -> builtins.str: ...
132132

133133
@property
134-
def type(self) -> builtins.type: ...
134+
def type(self) -> Type[generic]: ...
135135

136136

137137
_Dtype = dtype # to avoid name conflicts with ndarray.dtype
@@ -195,35 +195,24 @@ class flatiter:
195195
def __next__(self) -> Any: ...
196196

197197

198-
class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
199-
SupportsBytes, SupportsAbs[Any]):
200-
201-
imag: ndarray
202-
real: ndarray
203-
198+
_ArraySelf = TypeVar("_ArraySelf", bound=_ArrayOrScalarCommon)
199+
class _ArrayOrScalarCommon(SupportsInt, SupportsFloat, SupportsComplex,
200+
SupportsBytes, SupportsAbs[Any]):
204201
@property
205-
def T(self) -> ndarray: ...
202+
def T(self: _ArraySelf) -> _ArraySelf: ...
206203

207204
@property
208205
def base(self) -> Optional[ndarray]: ...
209206

210207
@property
211208
def dtype(self) -> _Dtype: ...
212-
@dtype.setter
213-
def dtype(self, value: _DtypeLike): ...
214-
215-
@property
216-
def ctypes(self) -> _ctypes: ...
217209

218210
@property
219211
def data(self) -> memoryview: ...
220212

221213
@property
222214
def flags(self) -> _flagsobj: ...
223215

224-
@property
225-
def flat(self) -> flatiter: ...
226-
227216
@property
228217
def size(self) -> int: ...
229218

@@ -238,22 +227,9 @@ class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
238227

239228
@property
240229
def shape(self) -> _Shape: ...
241-
@shape.setter
242-
def shape(self, value: _ShapeLike): ...
243230

244231
@property
245232
def strides(self) -> _Shape: ...
246-
@strides.setter
247-
def strides(self, value: _ShapeLike): ...
248-
249-
# Many of these special methods are irrelevant currently, since protocols
250-
# aren't supported yet. That said, I'm adding them for completeness.
251-
# https://docs.python.org/3/reference/datamodel.html
252-
def __len__(self) -> int: ...
253-
def __getitem__(self, key) -> Any: ...
254-
def __setitem__(self, key, value): ...
255-
def __iter__(self) -> Any: ...
256-
def __contains__(self, key) -> bool: ...
257233

258234
def __int__(self) -> int: ...
259235
def __float__(self) -> float: ...
@@ -269,17 +245,8 @@ class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
269245
def __str__(self) -> str: ...
270246
def __repr__(self) -> str: ...
271247

272-
def __index__(self) -> int: ...
273-
274-
def __copy__(self, order: str = ...) -> ndarray: ...
275-
def __deepcopy__(self, memo: dict) -> ndarray: ...
276-
277-
# https://github.com/numpy/numpy/blob/v1.13.0/numpy/lib/mixins.py#L63-L181
278-
279-
# TODO(shoyer): add overloads (returning ndarray) for cases where other is
280-
# known not to define __array_priority__ or __array_ufunc__, such as for
281-
# numbers or other numpy arrays. Or even better, use protocols (once they
282-
# work).
248+
def __copy__(self: _ArraySelf, order: str = ...) -> _ArraySelf: ...
249+
def __deepcopy__(self: _ArraySelf, memo: dict) -> _ArraySelf: ...
283250

284251
def __lt__(self, other): ...
285252
def __le__(self, other): ...
@@ -349,15 +316,122 @@ class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
349316
def __matmul__(self, other): ...
350317
def __rmatmul__(self, other): ...
351318

352-
def __neg__(self) -> ndarray: ...
353-
def __pos__(self) -> ndarray: ...
354-
def __abs__(self) -> ndarray: ...
355-
def __invert__(self) -> ndarray: ...
319+
def __neg__(self: _ArraySelf) -> _ArraySelf: ...
320+
def __pos__(self: _ArraySelf) -> _ArraySelf: ...
321+
def __abs__(self: _ArraySelf) -> _ArraySelf: ...
322+
def __invert__(self: _ArraySelf) -> _ArraySelf: ...
356323

357324
# TODO(shoyer): remove when all methods are defined
358325
def __getattr__(self, name) -> Any: ...
359326

360327

328+
class ndarray(_ArrayOrScalarCommon, Iterable, Sized):
329+
real: ndarray
330+
imag: ndarray
331+
332+
@property
333+
def dtype(self) -> _Dtype: ...
334+
@dtype.setter
335+
def dtype(self, value: _DtypeLike): ...
336+
337+
@property
338+
def ctypes(self) -> _ctypes: ...
339+
340+
@property
341+
def shape(self) -> _Shape: ...
342+
@shape.setter
343+
def shape(self, value: _ShapeLike): ...
344+
345+
@property
346+
def flat(self) -> flatiter: ...
347+
348+
@property
349+
def strides(self) -> _Shape: ...
350+
@strides.setter
351+
def strides(self, value: _ShapeLike): ...
352+
353+
# Many of these special methods are irrelevant currently, since protocols
354+
# aren't supported yet. That said, I'm adding them for completeness.
355+
# https://docs.python.org/3/reference/datamodel.html
356+
def __len__(self) -> int: ...
357+
def __getitem__(self, key) -> Any: ...
358+
def __setitem__(self, key, value): ...
359+
def __iter__(self) -> Any: ...
360+
def __contains__(self, key) -> bool: ...
361+
def __index__(self) -> int: ...
362+
363+
class generic(_ArrayOrScalarCommon):
364+
def __init__(self, value: Any = ...) -> None: ...
365+
@property
366+
def base(self) -> None: ...
367+
368+
class _real_generic(generic):
369+
@property
370+
def real(self: _ArraySelf) -> _ArraySelf: ...
371+
372+
@property
373+
def imag(self: _ArraySelf) -> _ArraySelf: ...
374+
375+
class number(generic):
376+
def __init__(
377+
self, value: Union[SupportsInt, SupportsFloat] = ...
378+
) -> None: ...
379+
class bool_(_real_generic): ...
380+
class object_(generic): ...
381+
class datetime64(_real_generic): ...
382+
383+
class integer(number, _real_generic): ...
384+
class signedinteger(integer): ...
385+
class int8(signedinteger): ...
386+
class int16(signedinteger): ...
387+
class int32(signedinteger): ...
388+
class int64(signedinteger): ...
389+
class timedelta64(signedinteger): ...
390+
391+
class unsignedinteger(integer): ...
392+
class uint8(unsignedinteger): ...
393+
class uint16(unsignedinteger): ...
394+
class uint32(unsignedinteger): ...
395+
class uint64(unsignedinteger): ...
396+
397+
class inexact(number): ...
398+
class floating(inexact, _real_generic): ...
399+
class float16(floating): ...
400+
class float32(floating): ...
401+
class float64(floating): ...
402+
403+
class complexfloating(inexact):
404+
def __init__(
405+
self,
406+
value: Union[SupportsInt, SupportsFloat, SupportsComplex,
407+
complex] = ...,
408+
) -> None: ...
409+
class complex64(complexfloating):
410+
@property
411+
def real(self) -> float32: ...
412+
@property
413+
def imag(self) -> float32: ...
414+
class complex128(complexfloating):
415+
@property
416+
def real(self) -> float64: ...
417+
@property
418+
def imag(self) -> float64: ...
419+
420+
class flexible(_real_generic): ...
421+
class void(flexible): ...
422+
class character(_real_generic): ...
423+
class bytes_(character): ...
424+
class str_(character): ...
425+
426+
# TODO(alan): Platform dependent types
427+
# longcomplex, longdouble, longfloat
428+
# bytes, short, intc, intp, longlong
429+
# half, single, double, longdouble
430+
# uint_, int_, float_, complex_
431+
# float128, complex256
432+
# float96
433+
434+
361435
def array(
362436
object: object,
363437
dtype: _DtypeLike = ...,

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77

88
[flake8]
99
ignore = E301, E302, E305, E701, E704
10+
exclude = .git,__pycache__,tests/reveal

test-requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
flake8==3.3.0
2-
3-
mypy==0.570.0
2+
flake8-pyi==17.3.0
3+
pytest==3.4.2
4+
mypy==0.580.0

tests/README.md

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
11
Testing
22
=======
33

4-
To run these tests:
4+
There are three main directories of tests right now:
55

6-
export MYPYPATH='..'
7-
mypy test_simple.py
6+
- `pass/` which contain Python files that must pass `mypy` checking with
7+
no type errors
8+
- `fail/` which contain Python files that must *fail* `mypy` checking
9+
with the annotated errors
10+
- `reveal/` which contain Python files that must output the correct
11+
types with `reveal_type`
812

9-
In future, this should change to use the test framework used by mypy.
13+
`fail` and `reveal` are annotated with comments that specify what error
14+
`mypy` threw and what type should be revealed respectively. The format
15+
looks like:
16+
17+
```python
18+
bad_function # E: <error message>
19+
reveal_type(x) # E: <type name>
20+
```
21+
22+
Right now, the error messages and types are must be **contained within
23+
corresponding mypy message**.
24+
25+
## Running the tests
26+
27+
We use `py.test` to orchestrate our tests. You can just run:
28+
29+
```
30+
py.test
31+
```
32+
33+
to run the entire test suite. To run `mypy` on a specific file (which
34+
can be useful for debugging), you can also run:
35+
36+
```
37+
$ cd tests
38+
$ MYPYPATH=.. mypy <file_path>
39+
```

tests/fail/scalars.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import numpy as np
2+
3+
# Construction
4+
5+
np.float32(3j) # E: incompatible type
6+
np.complex64([]) # E: incompatible type
7+
8+
np.complex64(1, 2) # E: Too many arguments
9+
# TODO: protocols (can't check for non-existent protocols w/ __getattr__)

tests/pass/scalars.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
3+
4+
# Construction
5+
class C:
6+
def __complex__(self):
7+
return 3j
8+
9+
10+
class A:
11+
def __float__(self):
12+
return 4
13+
14+
15+
np.complex32(3)
16+
np.complex64(3j)
17+
np.complex64(C())
18+
19+
np.int8(4)
20+
np.int16(3.4)
21+
np.int32(4)
22+
np.int64(-1)
23+
np.uint8(A())
24+
np.uint32()
25+
26+
np.float16(A())
27+
np.float32(16)
28+
np.float64(3.0)
29+
30+
np.bytes_(b"hello")
31+
np.str_("hello")
32+
33+
# Protocols
34+
float(np.int8(4))
35+
int(np.int16(5))
36+
np.int8(np.float32(6))
37+
38+
# TODO(alan): test after https://github.com/python/typeshed/pull/2004
39+
# complex(np.int32(8))
40+
41+
abs(np.int8(4))
42+
43+
# Array-ish semantics
44+
np.int8().real
45+
np.int16().imag
46+
np.int32().data
47+
np.int64().flags
48+
49+
np.uint8().itemsize * 2
50+
np.uint16().ndim + 1
51+
np.uint32().strides
52+
np.uint64().shape
File renamed without changes.

tests/reveal/scalars.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
3+
x = np.complex64(3 + 2j)
4+
5+
reveal_type(x.real) # E: numpy.float32
6+
reveal_type(x.imag) # E: numpy.float32
7+
8+
reveal_type(x.real.real) # E: numpy.float32
9+
reveal_type(x.real.imag) # E: numpy.float32
10+
11+
reveal_type(x.itemsize) # E: int
12+
reveal_type(x.shape) # E: tuple[builtins.int]
13+
reveal_type(x.strides) # E: tuple[builtins.int]

0 commit comments

Comments
 (0)