Skip to content

Commit d44d9e3

Browse files
Prashant Kumarpashu123
Prashant Kumar
authored andcommitted
[CONV] Pass the device and target info
The device and target info was hardcoded and now can be passed viacommand line to the compile command. Also, black formatted the files.
1 parent 7b5fd7c commit d44d9e3

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

convbench/conv_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def generate_mlir(config: ConvConfig):
275275

276276

277277
def compile_conv_config(
278-
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path
278+
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, device: str, target: str
279279
) -> tuple[Path, Optional[Path]]:
280280
mlir_file = kernel_dir / (config.get_name() + ".mlir")
281281
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
@@ -298,9 +298,9 @@ def compile_conv_config(
298298
"-o",
299299
f"{vmfb_file}",
300300
# Target Device: hip
301-
"--iree-hal-target-device=hip",
301+
f"--iree-hal-target-device={device}",
302302
# Device: MI300x
303-
"--iree-hip-target=gfx942",
303+
f"--iree-hip-target={target}",
304304
]
305305

306306
print(" ".join(exec_args))

convbench/shark_conv.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from problems import get_conv_configs
1313

1414

15-
def compile_conv(tag, config, kernel_dir, vmfb_dir):
16-
mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir)
15+
def compile_conv(tag, config, kernel_dir, vmfb_dir, device, target):
16+
mlir_file, vmfb_file = compile_conv_config(
17+
config, kernel_dir, vmfb_dir, device, target
18+
)
1719
return (tag, config, mlir_file, vmfb_file)
1820

1921

@@ -32,6 +34,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
3234
type=str,
3335
default="hip",
3436
)
37+
parser.add_argument(
38+
"--target",
39+
help="The device's target to execute benchmarks on",
40+
type=str,
41+
default="gfx942",
42+
)
3543
parser.add_argument(
3644
"--roofline",
3745
help="Comma seperated csv file list to generate roofline plot with",
@@ -68,7 +76,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
6876
device = args.device
6977

7078
compile_args = itertools.starmap(
71-
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
79+
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, device, args.target),
80+
configs,
7281
)
7382
with Pool(num_cpus) as pool:
7483
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))
@@ -137,7 +146,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
137146
config.S,
138147
config.input_dtype,
139148
config.output_dtype,
140-
round(benchmark_gemm_mean_time_us, 4),
149+
round(benchmark_gemm_mean_time_us, 4),
150+
141151
round(arithmetic_intensity, 4),
142152
round(tflops_per_second, 4),
143153
ok,

0 commit comments

Comments
 (0)