Skip to content

Commit 56a2aa7

Browse files
committed
MAINT: result_type cosmetic refactor
1 parent ea5deb1 commit 56a2aa7

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

array_api_strict/_data_type_functions.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -204,31 +204,25 @@ def result_type(
204204
# required by the spec rather than using np.result_type. NumPy implements
205205
# too many extra type promotions like int64 + uint64 -> float64, and does
206206
# value-based casting on scalar arrays.
207-
A = []
207+
dtypes = []
208208
scalars = []
209209
for a in arrays_and_dtypes:
210-
if isinstance(a, Array):
211-
a = a.dtype
210+
if isinstance(a, DType):
211+
dtypes.append(a)
212+
elif isinstance(a, Array):
213+
dtypes.append(a.dtype)
212214
elif isinstance(a, (bool, int, float, complex)):
213215
scalars.append(a)
214-
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
216+
else:
215217
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
216-
A.append(a)
217-
218-
# remove python scalars
219-
B = [a for a in A if not isinstance(a, (bool, int, float, complex))]
220218

221-
if len(B) == 0:
219+
if not dtypes:
222220
raise ValueError("at least one array or dtype is required")
223-
elif len(B) == 1:
224-
result = B[0]
225-
else:
226-
t = B[0]
227-
for t2 in B[1:]:
228-
t = _result_type(t, t2)
229-
result = t
221+
result = dtypes[0]
222+
for t2 in dtypes[1:]:
223+
result = _result_type(result, t2)
230224

231-
if len(scalars) == 0:
225+
if not scalars:
232226
return result
233227

234228
if get_array_api_strict_flags()['api_version'] <= '2023.12':

0 commit comments

Comments
 (0)