Skip to content

Commit d13ab1b

Browse files
authored
Merge pull request #142 from crusaderky/result_type_nits
MAINT: `result_type` cosmetic refactor
2 parents ea5deb1 + 72eabc4 commit d13ab1b

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

array_api_strict/_data_type_functions.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -204,35 +204,31 @@ def result_type(
204204
# required by the spec rather than using np.result_type. NumPy implements
205205
# too many extra type promotions like int64 + uint64 -> float64, and does
206206
# value-based casting on scalar arrays.
207-
A = []
207+
dtypes = []
208208
scalars = []
209209
for a in arrays_and_dtypes:
210-
if isinstance(a, Array):
211-
a = a.dtype
210+
if isinstance(a, DType):
211+
dtypes.append(a)
212+
elif isinstance(a, Array):
213+
dtypes.append(a.dtype)
212214
elif isinstance(a, (bool, int, float, complex)):
213215
scalars.append(a)
214-
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
215-
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
216-
A.append(a)
217-
218-
# remove python scalars
219-
B = [a for a in A if not isinstance(a, (bool, int, float, complex))]
216+
else:
217+
raise TypeError(
218+
"result_type() inputs must be Array API arrays, dtypes, or scalars"
219+
)
220220

221-
if len(B) == 0:
221+
if not dtypes:
222222
raise ValueError("at least one array or dtype is required")
223-
elif len(B) == 1:
224-
result = B[0]
225-
else:
226-
t = B[0]
227-
for t2 in B[1:]:
228-
t = _result_type(t, t2)
229-
result = t
223+
result = dtypes[0]
224+
for t2 in dtypes[1:]:
225+
result = _result_type(result, t2)
230226

231-
if len(scalars) == 0:
227+
if not scalars:
232228
return result
233229

234230
if get_array_api_strict_flags()['api_version'] <= '2023.12':
235-
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
231+
raise TypeError("result_type() inputs must be Array API arrays or dtypes")
236232

237233
# promote python scalars given the result_type for all arrays/dtypes
238234
from ._creation_functions import empty

0 commit comments

Comments
 (0)