Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b2c0a55

Browse files
author
Chaluvadi
committedMar 11, 2024
tests for blas operations
1 parent a657583 commit b2c0a55

File tree

1 file changed

+542
-0
lines changed

1 file changed

+542
-0
lines changed
 

‎tests/test_blas.py

+542
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,542 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtypes
6+
import arrayfire_wrapper.lib as wrapper
7+
from arrayfire_wrapper.defines import AFArray
8+
from arrayfire_wrapper.lib._constants import MatProp
9+
10+
11+
@pytest.mark.parametrize(
12+
"shape",
13+
[
14+
(1,),
15+
(100,),
16+
(1000,),
17+
(10000,),
18+
(1, 1),
19+
(10, 10),
20+
(100, 100),
21+
(1000, 1000),
22+
(1, 1, 1),
23+
(10, 10, 10),
24+
(100, 100, 100),
25+
(1000, 1000, 1000),
26+
(1, 1, 1, 1),
27+
(10, 10, 10, 10),
28+
(100, 100, 100, 100),
29+
(1000, 1000, 1000, 1000),
30+
],
31+
)
32+
def test_dot_res(shape: tuple) -> None:
33+
"""Test if the dot product outputs an AFArray with a dimension of 1"""
34+
dtype = dtypes.f32
35+
shape = (10,)
36+
x = wrapper.randu(shape, dtype)
37+
38+
result = wrapper.dot(x, x, MatProp.NONE, MatProp.NONE)
39+
40+
assert isinstance(result, AFArray)
41+
42+
43+
@pytest.mark.parametrize(
44+
"shape_pairs",
45+
[
46+
[(5, 1), (1, 6)],
47+
[(10, 10), (9, 10)],
48+
[(9, 8), (10, 10)],
49+
[(random.randint(1, 10), 100), (1000, random.randint(1, 10))],
50+
[(random.randint(1, 10), 100000), (2, random.randint(1, 10))],
51+
],
52+
)
53+
def test_dot_invalid_shape_comp(shape_pairs: list) -> None:
54+
"""Test if an improper shape pair is properly handled"""
55+
with pytest.raises(RuntimeError):
56+
dtype = dtypes.f32
57+
x = wrapper.randu(shape_pairs[0], dtype)
58+
y = wrapper.randu(shape_pairs[1], dtype)
59+
60+
wrapper.dot(x, y, MatProp.NONE, MatProp.NONE)
61+
62+
63+
def test_dot_empty_vector() -> None:
64+
"""Test if an empty array passed into the dot product is properly handled"""
65+
with pytest.raises(RuntimeError):
66+
empty_shape = (0,)
67+
dtype = dtypes.f32
68+
69+
x = wrapper.randu(empty_shape, dtype)
70+
wrapper.dot(x, x, MatProp.NONE, MatProp.NONE)
71+
72+
73+
@pytest.mark.parametrize(
74+
"shape",
75+
[
76+
(1,),
77+
(100,),
78+
(1000,),
79+
(10000,),
80+
(1, 1),
81+
(10, 10),
82+
(100, 100),
83+
(1000, 1000),
84+
(1, 1, 1),
85+
(10, 10, 10),
86+
(100, 100, 100),
87+
(1000, 1000, 1000),
88+
(1, 1, 1, 1),
89+
(10, 10, 10, 10),
90+
(100, 100, 100, 100),
91+
(1000, 1000, 1000, 1000),
92+
],
93+
)
94+
def test_dot_diff_dtype(shape: tuple) -> None:
95+
"""Test of dot product of arrays of different dtypes is properly handled"""
96+
with pytest.raises(RuntimeError):
97+
x = wrapper.randu(shape, dtypes.f32)
98+
y = wrapper.randu(shape, dtypes.c32)
99+
100+
wrapper.dot(x, y, MatProp.NONE, MatProp.NONE)
101+
102+
103+
@pytest.mark.parametrize(
104+
"shape",
105+
[(1,), (10,), (100,), (1000,)],
106+
)
107+
@pytest.mark.parametrize(
108+
"dtype_index",
109+
[i for i in range(13)],
110+
)
111+
def test_dot_invalid_dtype(shape: tuple, dtype_index: int) -> None:
112+
"""Test if improper dtypes are properly handled"""
113+
if dtype_index in [12, 0, 2, 1, 3]:
114+
pytest.skip()
115+
116+
with pytest.raises(RuntimeError):
117+
x = wrapper.randu(shape, dtypes.s16)
118+
y = wrapper.randu(shape, dtypes.s16)
119+
120+
wrapper.dot(x, y, MatProp.NONE, MatProp.NONE)
121+
122+
123+
# dot all tests
124+
@pytest.mark.parametrize(
125+
"shape",
126+
[
127+
(1,),
128+
(100,),
129+
(1000,),
130+
(10000,),
131+
(1, 1),
132+
(10, 10),
133+
(100, 100),
134+
(1000, 1000),
135+
(1, 1, 1),
136+
(10, 10, 10),
137+
(100, 100, 100),
138+
(1000, 1000, 1000),
139+
(1, 1, 1, 1),
140+
(10, 10, 10, 10),
141+
(100, 100, 100, 100),
142+
(1000, 1000, 1000, 1000),
143+
],
144+
)
145+
def test_dot_all_res_float(shape: tuple) -> None:
146+
"""Test if the dot_all product outputs a float scalar value"""
147+
dtype = dtypes.f32
148+
shape = (10,)
149+
x = wrapper.randu(shape, dtype)
150+
151+
result = wrapper.dot_all(x, x, MatProp.NONE, MatProp.NONE)
152+
153+
assert isinstance(result, float)
154+
155+
156+
@pytest.mark.parametrize(
157+
"shape",
158+
[
159+
(1,),
160+
(100,),
161+
(1000,),
162+
(10000,),
163+
(1, 1),
164+
(10, 10),
165+
(100, 100),
166+
(1000, 1000),
167+
(1, 1, 1),
168+
(10, 10, 10),
169+
(100, 100, 100),
170+
(1000, 1000, 1000),
171+
(1, 1, 1, 1),
172+
(10, 10, 10, 10),
173+
(100, 100, 100, 100),
174+
(1000, 1000, 1000, 1000),
175+
],
176+
)
177+
def test_dot_all_res_complex(shape: tuple) -> None:
178+
"""Test if the dot_all product outputs a complex scalar value"""
179+
dtype = dtypes.c32
180+
shape = (10,)
181+
x = wrapper.randu(shape, dtype)
182+
183+
result = wrapper.dot_all(x, x, MatProp.NONE, MatProp.NONE)
184+
185+
assert isinstance(result, complex)
186+
187+
188+
@pytest.mark.parametrize(
189+
"shape",
190+
[
191+
(1,),
192+
(100,),
193+
(1000,),
194+
(10000,),
195+
(1, 1),
196+
(10, 10),
197+
(100, 100),
198+
(1000, 1000),
199+
(1, 1, 1),
200+
(10, 10, 10),
201+
(100, 100, 100),
202+
(1000, 1000, 1000),
203+
(1, 1, 1, 1),
204+
(10, 10, 10, 10),
205+
(100, 100, 100, 100),
206+
(1000, 1000, 1000, 1000),
207+
],
208+
)
209+
def test_dot_all_diff_dtype(shape: tuple) -> None:
210+
"""Test if a dot product of arrays of different dtypes is properly handled"""
211+
with pytest.raises(RuntimeError):
212+
x = wrapper.randu(shape, dtypes.f32)
213+
y = wrapper.randu(shape, dtypes.c32)
214+
215+
wrapper.dot_all(x, y, MatProp.NONE, MatProp.NONE)
216+
217+
218+
@pytest.mark.parametrize(
219+
"shape",
220+
[(1,), (10,), (100,), (1000,)],
221+
)
222+
@pytest.mark.parametrize(
223+
"dtype_index",
224+
[i for i in range(13)],
225+
)
226+
def test_dot_all_invalid_dtype(shape: tuple, dtype_index: int) -> None:
227+
"""Test if dot_all properly handles improper dtypes"""
228+
if dtype_index in [12, 0, 2, 1, 3]:
229+
pytest.skip()
230+
231+
with pytest.raises(RuntimeError):
232+
x = wrapper.randu(shape, dtypes.s16)
233+
y = wrapper.randu(shape, dtypes.s16)
234+
235+
wrapper.dot_all(x, y, MatProp.NONE, MatProp.NONE)
236+
237+
238+
@pytest.mark.parametrize(
239+
"shape_pairs",
240+
[
241+
[(5, 1), (1, 6)],
242+
[(10, 10), (9, 10)],
243+
[(9, 8), (10, 10)],
244+
[(random.randint(1, 10), 100), (1000, random.randint(1, 10))],
245+
[(random.randint(1, 10), 100000), (2, random.randint(1, 10))],
246+
],
247+
)
248+
def test_dot_all_invalid_shape_comp(shape_pairs: list) -> None:
249+
"""Test if dot_all properly handles an improper shape pair"""
250+
with pytest.raises(RuntimeError):
251+
dtype = dtypes.f32
252+
x = wrapper.randu(shape_pairs[0], dtype)
253+
y = wrapper.randu(shape_pairs[1], dtype)
254+
255+
wrapper.dot_all(x, y, MatProp.NONE, MatProp.NONE)
256+
257+
258+
def test_dot_all_empty_vector() -> None:
259+
"""Test if an empty array passed into the dot product is properly handled"""
260+
with pytest.raises(RuntimeError):
261+
empty_shape = (0,)
262+
dtype = dtypes.f32
263+
264+
x = wrapper.randu(empty_shape, dtype)
265+
wrapper.dot(x, x, MatProp.NONE, MatProp.NONE)
266+
267+
268+
# gemm tests
269+
@pytest.mark.parametrize(
270+
"shape_pairs",
271+
[
272+
[(random.randint(1, 10), 10), (10, random.randint(1, 10))],
273+
[(random.randint(1, 10), 100), (100, random.randint(1, 10))],
274+
[(random.randint(1, 10), 1000), (1000, random.randint(1, 10))],
275+
],
276+
)
277+
def test_gemm_correct_shape_2d(shape_pairs: list) -> None:
278+
"""Test if matmul outputs an array with the correct shape given 2d inputs"""
279+
dtype = dtypes.f32
280+
x = wrapper.randu(shape_pairs[0], dtype)
281+
y = wrapper.randu(shape_pairs[1], dtype)
282+
283+
result_shape = (shape_pairs[0][0], shape_pairs[1][1])
284+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
285+
286+
assert wrapper.get_dims(result)[0:2] == result_shape
287+
288+
289+
@pytest.mark.parametrize(
290+
"shape_pairs",
291+
[
292+
[(random.randint(1, 10), 1, 2), (1, random.randint(1, 10), 2)],
293+
[(random.randint(1, 10), 10, 2), (10, random.randint(1, 10), 2)],
294+
[(random.randint(1, 10), 100, 2), (100, random.randint(1, 10), 2)],
295+
[(random.randint(1, 10), 1000, 2), (1000, random.randint(1, 10), 2)],
296+
],
297+
)
298+
def test_gemm_correct_shape_3d(shape_pairs: list) -> None:
299+
"""Test if matul outpus an array with the correct shape given 3d inputs"""
300+
dtype = dtypes.f32
301+
x = wrapper.randu(shape_pairs[0], dtype)
302+
y = wrapper.randu(shape_pairs[1], dtype)
303+
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2])
304+
305+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
306+
assert wrapper.get_dims(result)[0:3] == result_shape
307+
308+
309+
@pytest.mark.parametrize(
310+
"shape_pairs",
311+
[
312+
[(random.randint(1, 10), 1, 2, 2), (1, random.randint(1, 10), 2, 2)],
313+
[(random.randint(1, 10), 10, 2, 2), (10, random.randint(1, 10), 2, 2)],
314+
[(random.randint(1, 10), 100, 2, 2), (100, random.randint(1, 10), 2, 2)],
315+
[(random.randint(1, 10), 1000, 2, 2), (1000, random.randint(1, 10), 2, 2)],
316+
],
317+
)
318+
def test_gemm_correct_shape_4d(shape_pairs: list) -> None:
319+
"""Test if matmul outpus an array with the correct shape given 4d inputs"""
320+
dtype = dtypes.f32
321+
x = wrapper.randu(shape_pairs[0], dtype)
322+
y = wrapper.randu(shape_pairs[1], dtype)
323+
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2], shape_pairs[0][3])
324+
325+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
326+
assert wrapper.get_dims(result)[0:4] == result_shape
327+
328+
329+
@pytest.mark.parametrize(
330+
"dtype",
331+
[dtypes.f32, dtypes.c32, dtypes.f64, dtypes.c64],
332+
)
333+
def test_gemm_correct_dtype(dtype: dtypes.Dtype) -> None:
334+
"""Test if matmul outputs an array with the correct dtype"""
335+
if dtype in [dtypes.f64, dtypes.c64] and not wrapper.get_dbl_support():
336+
pytest.skip()
337+
338+
shape = (100, 100)
339+
x = wrapper.randu(shape, dtype)
340+
y = wrapper.randu(shape, dtype)
341+
342+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
343+
344+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
345+
346+
347+
@pytest.mark.parametrize(
348+
"shape_pairs",
349+
[
350+
[(5, 2), (1, 6)],
351+
[(10, 10), (9, 10)],
352+
[(9, 8), (10, 10)],
353+
[(random.randint(1, 10), 100), (1000, random.randint(1, 10))],
354+
[(random.randint(1, 10), 100000), (2, random.randint(1, 10))],
355+
],
356+
)
357+
def test_gemm_invalid_pair(shape_pairs: list) -> None:
358+
"""Test if matmul handles improper shape pairs"""
359+
with pytest.raises(RuntimeError):
360+
dtype = dtypes.f32
361+
x = wrapper.randu(shape_pairs[0], dtype)
362+
y = wrapper.randu(shape_pairs[1], dtype)
363+
364+
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
365+
366+
367+
def test_gemm_empty_shape() -> None:
368+
"""Test if matmul handles an empty array"""
369+
with pytest.raises(RuntimeError):
370+
empty_shape = (0,)
371+
dtype = dtypes.f32
372+
373+
x = wrapper.randu(empty_shape, dtype)
374+
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
375+
376+
377+
@pytest.mark.parametrize(
378+
"dtype_index",
379+
[i for i in range(13)],
380+
)
381+
def test_gemm_invalid_dtype(dtype_index: int) -> None:
382+
"""Test if matmul handles an array with an invalid dtype - integer, long, short"""
383+
shape = (random.randint(1, 10), random.randint(1, 10))
384+
if dtype_index in [12, 0, 2, 1, 3]:
385+
pytest.skip()
386+
387+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
388+
389+
with pytest.raises(TypeError):
390+
x = wrapper.randu(shape, dtype)
391+
y = wrapper.randu(shape, dtype)
392+
393+
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
394+
395+
396+
def test_gemm_empty_matrix() -> None:
397+
"""Test if matmul handles an empty array passed in"""
398+
with pytest.raises(RuntimeError):
399+
empty_shape = (0,)
400+
dtype = dtypes.f32
401+
402+
x = wrapper.randu(empty_shape, dtype)
403+
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
404+
405+
406+
# matmul tests
407+
@pytest.mark.parametrize(
408+
"shape_pairs",
409+
[
410+
[(random.randint(1, 10), 1), (1, random.randint(1, 10))],
411+
[(random.randint(1, 10), 10), (10, random.randint(1, 10))],
412+
[(random.randint(1, 10), 100), (100, random.randint(1, 10))],
413+
[(random.randint(1, 10), 1000), (1000, random.randint(1, 10))],
414+
],
415+
)
416+
def test_matmul_correct_shape_2d(shape_pairs: list) -> None:
417+
"""Test if matmul outputs an array with the correct shape given 2d inputs"""
418+
dtype = dtypes.f32
419+
x = wrapper.randu(shape_pairs[0], dtype)
420+
y = wrapper.randu(shape_pairs[1], dtype)
421+
422+
result_shape = (shape_pairs[0][0], shape_pairs[1][1])
423+
result = wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
424+
425+
assert wrapper.get_dims(result)[0 : len(shape_pairs[0])] == result_shape # noqa: E203
426+
427+
428+
@pytest.mark.parametrize(
429+
"shape_pairs",
430+
[
431+
[(random.randint(1, 10), 1, 2), (1, random.randint(1, 10), 2)],
432+
[(random.randint(1, 10), 10, 2), (10, random.randint(1, 10), 2)],
433+
[(random.randint(1, 10), 100, 2), (100, random.randint(1, 10), 2)],
434+
[(random.randint(1, 10), 1000, 2), (1000, random.randint(1, 10), 2)],
435+
],
436+
)
437+
def test_matmul_correct_shape_3d(shape_pairs: list) -> None:
438+
"""Test if matul outpus an array with the correct shape given 3d inputs"""
439+
dtype = dtypes.f32
440+
x = wrapper.randu(shape_pairs[0], dtype)
441+
y = wrapper.randu(shape_pairs[1], dtype)
442+
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2])
443+
444+
result = wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
445+
assert wrapper.get_dims(result)[0:3] == result_shape
446+
447+
448+
@pytest.mark.parametrize(
449+
"shape_pairs",
450+
[
451+
[(random.randint(1, 10), 1, 2, 2), (1, random.randint(1, 10), 2, 2)],
452+
[(random.randint(1, 10), 10, 2, 2), (10, random.randint(1, 10), 2, 2)],
453+
[(random.randint(1, 10), 100, 2, 2), (100, random.randint(1, 10), 2, 2)],
454+
[(random.randint(1, 10), 1000, 2, 2), (1000, random.randint(1, 10), 2, 2)],
455+
],
456+
)
457+
def test_matmul_correct_shape_4d(shape_pairs: list) -> None:
458+
"""Test if matmul outpus an array with the correct shape given 4d inputs"""
459+
dtype = dtypes.f32
460+
x = wrapper.randu(shape_pairs[0], dtype)
461+
y = wrapper.randu(shape_pairs[1], dtype)
462+
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2], shape_pairs[0][3])
463+
464+
result = wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
465+
assert wrapper.get_dims(result)[0:4] == result_shape
466+
467+
468+
@pytest.mark.parametrize(
469+
"dtype",
470+
[dtypes.f16, dtypes.f32, dtypes.c32, dtypes.f64, dtypes.c64],
471+
)
472+
def test_matmul_correct_dtype(dtype: dtypes.Dtype) -> None:
473+
"""Test if matmul outputs an array with the correct dtype"""
474+
if dtype in [dtypes.f64, dtypes.c64] and not wrapper.get_dbl_support():
475+
pytest.skip()
476+
477+
shape = (100, 100)
478+
x = wrapper.randu(shape, dtype)
479+
y = wrapper.randu(shape, dtype)
480+
481+
result = wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
482+
483+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
484+
485+
486+
@pytest.mark.parametrize(
487+
"shape_pairs",
488+
[
489+
[(5, 2), (1, 6)],
490+
[(10, 10), (9, 10)],
491+
[(9, 8), (10, 10)],
492+
[(random.randint(1, 10), 100), (1000, random.randint(1, 10))],
493+
[(random.randint(1, 10), 100000), (2, random.randint(1, 10))],
494+
],
495+
)
496+
def test_matmul_invalid_pair(shape_pairs: list) -> None:
497+
"""Test if matmul handles improper shape pairs"""
498+
with pytest.raises(RuntimeError):
499+
dtype = dtypes.f32
500+
x = wrapper.randu(shape_pairs[0], dtype)
501+
y = wrapper.randu(shape_pairs[1], dtype)
502+
503+
wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
504+
505+
506+
def test_matmul_empty_shape() -> None:
507+
"""Test if matmul handles an empty array"""
508+
with pytest.raises(RuntimeError):
509+
empty_shape = (0,)
510+
dtype = dtypes.f32
511+
512+
x = wrapper.randu(empty_shape, dtype)
513+
wrapper.matmul(x, x, MatProp.NONE, MatProp.NONE)
514+
515+
516+
@pytest.mark.parametrize(
517+
"dtype_index",
518+
[i for i in range(13)],
519+
)
520+
def test_matmul_invalid_dtype(dtype_index: int) -> None:
521+
"""Test if matmul handles an array with an invalid dtype - integer, long, short"""
522+
shape = (random.randint(1, 10), random.randint(1, 10))
523+
if dtype_index in [12, 0, 2, 1, 3]:
524+
pytest.skip()
525+
526+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
527+
528+
with pytest.raises(RuntimeError):
529+
x = wrapper.randu(shape, dtype)
530+
y = wrapper.randu(shape, dtype)
531+
532+
wrapper.matmul(x, y, MatProp.NONE, MatProp.NONE)
533+
534+
535+
def test_matmul_empty_matrix() -> None:
536+
"""Test if matmul handles an empty array passed in"""
537+
with pytest.raises(RuntimeError):
538+
empty_shape = (0,)
539+
dtype = dtypes.f32
540+
541+
x = wrapper.randu(empty_shape, dtype)
542+
wrapper.matmul(x, x, MatProp.NONE, MatProp.NONE)

0 commit comments

Comments
 (0)
Please sign in to comment.