@@ -204,35 +204,33 @@ def result_type(
204
204
# required by the spec rather than using np.result_type. NumPy implements
205
205
# too many extra type promotions like int64 + uint64 -> float64, and does
206
206
# value-based casting on scalar arrays.
207
- A = []
207
+ dtypes = []
208
208
scalars = []
209
209
for a in arrays_and_dtypes :
210
- if isinstance (a , Array ):
211
- a = a .dtype
210
+ if isinstance (a , DType ):
211
+ dtypes .append (a )
212
+ elif isinstance (a , Array ):
213
+ dtypes .append (a .dtype )
212
214
elif isinstance (a , (bool , int , float , complex )):
213
215
scalars .append (a )
214
- elif isinstance (a , np .ndarray ) or a not in _all_dtypes :
215
- raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
216
- A .append (a )
217
-
218
- # remove python scalars
219
- B = [a for a in A if not isinstance (a , (bool , int , float , complex ))]
216
+ else :
217
+ raise TypeError (
218
+ "result_type() inputs must be Array API arrays, dtypes, or scalars"
219
+ )
220
220
221
- if len ( B ) == 0 :
221
+ if not dtypes :
222
222
raise ValueError ("at least one array or dtype is required" )
223
- elif len (B ) == 1 :
224
- result = B [0 ]
225
- else :
226
- t = B [0 ]
227
- for t2 in B [1 :]:
228
- t = _result_type (t , t2 )
229
- result = t
223
+ result = dtypes [0 ]
224
+ for t2 in dtypes [1 :]:
225
+ result = _result_type (result , t2 )
230
226
231
- if len ( scalars ) == 0 :
227
+ if not scalars :
232
228
return result
233
229
234
230
if get_array_api_strict_flags ()['api_version' ] <= '2023.12' :
235
- raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
231
+ raise TypeError (
232
+ "result_type() inputs must be Array API arrays, dtypes, or scalars"
233
+ )
236
234
237
235
# promote python scalars given the result_type for all arrays/dtypes
238
236
from ._creation_functions import empty
0 commit comments