Skip to content

Commit 17c17f5

Browse files
authored
Merge pull request #117 from crusaderky/isdtype
isdtype() should raise if parameter is not a dtype
2 parents e4b6bfe + c441971 commit 17c17f5

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

.gitignore

+4-4
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ ENV/
128128
env.bak/
129129
venv.bak/
130130

131-
# Spyder project settings
131+
# Project settings
132+
.idea
133+
.ropeproject
132134
.spyderproject
133135
.spyproject
134-
135-
# Rope project settings
136-
.ropeproject
136+
.vscode
137137

138138
# mkdocs documentation
139139
/site

array_api_strict/_data_type_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ def isdtype(
167167
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
168168
for more details
169169
"""
170+
if not isinstance(dtype, _DType):
171+
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
172+
170173
if isinstance(kind, tuple):
171174
# Disallow nested tuples
172175
if any(isinstance(k, tuple) for k in kind):

array_api_strict/tests/test_data_type_functions.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@ def test_can_cast(from_, to, expected):
3131
def test_isdtype_strictness():
3232
assert_raises(TypeError, lambda: isdtype(float64, 64))
3333
assert_raises(ValueError, lambda: isdtype(float64, 'f8'))
34-
3534
assert_raises(TypeError, lambda: isdtype(float64, (('integral',),)))
35+
assert_raises(TypeError, lambda: isdtype(float64, None))
36+
assert_raises(TypeError, lambda: isdtype(np.float64, float64))
37+
assert_raises(TypeError, lambda: isdtype(asarray(1.0), float64))
38+
3639
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
3740
warnings.simplefilter("always")
3841
isdtype(float64, np.object_)
3942
assert len(w) == 1
4043
assert issubclass(w[-1].category, UserWarning)
4144

42-
assert_raises(TypeError, lambda: isdtype(float64, None))
4345
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
4446
warnings.simplefilter("always")
4547
isdtype(float64, np.float64)

0 commit comments

Comments
 (0)