Skip to content

Commit 788e6be

Browse files
committed
MAINT: result_type cosmetic refactor
1 parent ea5deb1 commit 788e6be

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

array_api_strict/_data_type_functions.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -204,35 +204,33 @@ def result_type(
204204
# required by the spec rather than using np.result_type. NumPy implements
205205
# too many extra type promotions like int64 + uint64 -> float64, and does
206206
# value-based casting on scalar arrays.
207-
A = []
207+
dtypes = []
208208
scalars = []
209209
for a in arrays_and_dtypes:
210-
if isinstance(a, Array):
211-
a = a.dtype
210+
if isinstance(a, DType):
211+
dtypes.append(a)
212+
elif isinstance(a, Array):
213+
dtypes.append(a.dtype)
212214
elif isinstance(a, (bool, int, float, complex)):
213215
scalars.append(a)
214-
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
215-
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
216-
A.append(a)
217-
218-
# remove python scalars
219-
B = [a for a in A if not isinstance(a, (bool, int, float, complex))]
216+
else:
217+
raise TypeError(
218+
"result_type() inputs must be Array API arrays, dtypes, or scalars"
219+
)
220220

221-
if len(B) == 0:
221+
if not dtypes:
222222
raise ValueError("at least one array or dtype is required")
223-
elif len(B) == 1:
224-
result = B[0]
225-
else:
226-
t = B[0]
227-
for t2 in B[1:]:
228-
t = _result_type(t, t2)
229-
result = t
223+
result = dtypes[0]
224+
for t2 in dtypes[1:]:
225+
result = _result_type(result, t2)
230226

231-
if len(scalars) == 0:
227+
if not scalars:
232228
return result
233229

234230
if get_array_api_strict_flags()['api_version'] <= '2023.12':
235-
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
231+
raise TypeError(
232+
"result_type() inputs must be Array API arrays, dtypes, or scalars"
233+
)
236234

237235
# promote python scalars given the result_type for all arrays/dtypes
238236
from ._creation_functions import empty

0 commit comments

Comments
 (0)