Skip to content

Commit 8289e65

Browse files
authored
Set max_output_length to 1024 by default and use max_output_length in inference requests (#10)
1 parent 41ad033 commit 8289e65

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

benchmarks/benchmark_serving.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def sample_requests(
124124
dataset_path: str,
125125
num_requests: int,
126126
tokenizer: Any,
127+
max_output_length: int,
127128
) -> List[InputRequest]:
128129
# Load the dataset.
129130
with open(dataset_path) as f:
@@ -167,7 +168,7 @@ def sample_requests(
167168
if prompt_len > 1024 or prompt_len + output_len > 2048:
168169
# Prune too long sequences.
169170
continue
170-
reqeust = InputRequest(prompt, prompt_len, output, output_len)
171+
reqeust = InputRequest(prompt, prompt_len, output, max_output_length)
171172
filtered_dataset.append(reqeust)
172173

173174
# Sample the requests.
@@ -388,7 +389,7 @@ def main(args: argparse.Namespace):
388389
if tokenizer == "test" or args.dataset == "test":
389390
input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
390391
else:
391-
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
392+
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
392393

393394
benchmark_result, request_outputs = asyncio.run(
394395
benchmark(
@@ -501,6 +502,14 @@ def main(args: argparse.Namespace):
501502
default=150,
502503
help="The maximum number of mock requests to send for benchmark testing.",
503504
)
505+
506+
parser.add_argument(
507+
"--max-output-length",
508+
type=int,
509+
default=1024,
510+
help="The maximum output length for reference request.",
511+
)
512+
504513
parser.add_argument("--seed", type=int, default=0)
505514
parser.add_argument(
506515
"--disable-tqdm",

0 commit comments

Comments
 (0)