Skip to content

Commit b0ae45a

Browse files
adamomainzfacebook-github-bot
authored andcommitted
fixing OSS ci for tritonbench (#2439)
Summary: Pull Request resolved: #2439 two problems found so far in CI runs 1. the ci file was not added to the target deps so the ci file couldnt be found even if the flag was used 2. `--op` was required so we fixed that here where `--op` is only required if `--ci` is not used Reviewed By: danzimm, xuzhao9 Differential Revision: D61809602 fbshipit-source-id: e84844f0b6a697892222a008f45a8c20db8ef6bd
1 parent babb128 commit b0ae45a

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

userbenchmark/triton/ci.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import List, Any, Dict
55

66
from . import BM_NAME
7-
from .run import parse_args, _run
7+
from .run import get_parser, _run
88
from userbenchmark.utils import get_default_output_json_path, get_output_json
99
from torchbenchmark.util.triton_op import BenchmarkOperatorResult
1010

@@ -22,7 +22,7 @@ def run_ci():
2222
ci_result = []
2323
for test_opts in CI_TESTS:
2424
logging.info(f"Running the test opts: {test_opts}")
25-
test_args, test_extra_args = parse_args(test_opts)
25+
test_args, test_extra_args = get_parser(test_opts).parse_known_args(test_opts)
2626
metrics = _run(test_args, test_extra_args)
2727
ci_result.append(metrics)
2828
result = ci_result_to_userbenchmark_json(ci_result)

userbenchmark/triton/run.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525

2626
TRITON_BENCH_CSV_DUMP_PATH = tempfile.gettempdir() + "/tritonbench/"
2727

28-
def get_parser():
28+
def get_parser(args = None):
2929
parser = argparse.ArgumentParser(allow_abbrev=False)
3030
parser.add_argument(
3131
"--op",
3232
type=str,
33-
required=True,
33+
required=False,
3434
help="Operator to benchmark."
3535
)
3636
parser.add_argument(
@@ -140,6 +140,12 @@ def get_parser():
140140
)
141141
if not hasattr(torch_version, "git_version"):
142142
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
143+
144+
args, extra_args = parser.parse_known_args(args)
145+
if args.op and args.ci:
146+
parser.error("cannot specify operator when in CI mode")
147+
elif not args.op and not args.ci:
148+
parser.error("must specify operator when not in CI mode")
143149
return parser
144150

145151
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
@@ -187,5 +193,6 @@ def run(args: List[str] = []):
187193
from .ci import run_ci
188194
run_ci()
189195
return
196+
190197
with gpu_lockdown(args.gpu_lockdown):
191198
_run(args, extra_args)

0 commit comments

Comments
 (0)