12
12
from problems import get_conv_configs
13
13
14
14
15
- def compile_conv (tag , config , kernel_dir , vmfb_dir ):
16
- mlir_file , vmfb_file = compile_conv_config (config , kernel_dir , vmfb_dir )
15
+ def compile_conv (tag , config , kernel_dir , vmfb_dir , device , target ):
16
+ mlir_file , vmfb_file = compile_conv_config (
17
+ config , kernel_dir , vmfb_dir , device , target
18
+ )
17
19
return (tag , config , mlir_file , vmfb_file )
18
20
19
21
@@ -32,6 +34,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
32
34
type = str ,
33
35
default = "hip" ,
34
36
)
37
+ parser .add_argument (
38
+ "--target" ,
39
+ help = "The device's target to execute benchmarks on" ,
40
+ type = str ,
41
+ default = "gfx942" ,
42
+ )
35
43
parser .add_argument (
36
44
"--roofline" ,
37
45
help = "Comma seperated csv file list to generate roofline plot with" ,
@@ -68,7 +76,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
68
76
device = args .device
69
77
70
78
compile_args = itertools .starmap (
71
- lambda tag , config : (tag , config , kernel_dir , vmfb_dir ), configs
79
+ lambda tag , config : (tag , config , kernel_dir , vmfb_dir , device , args .target ),
80
+ configs ,
72
81
)
73
82
with Pool (num_cpus ) as pool :
74
83
compilation_results = list (tqdm (pool .starmap (compile_conv , list (compile_args ))))
@@ -137,7 +146,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
137
146
config .S ,
138
147
config .input_dtype ,
139
148
config .output_dtype ,
140
- round (benchmark_gemm_mean_time_us , 4 ),
149
+ round (benchmark_gemm_mean_time_us , 4 ),
150
+
141
151
round (arithmetic_intensity , 4 ),
142
152
round (tflops_per_second , 4 ),
143
153
ok ,
0 commit comments