@@ -124,6 +124,7 @@ def sample_requests(
124
124
dataset_path : str ,
125
125
num_requests : int ,
126
126
tokenizer : Any ,
127
+ max_output_length : int ,
127
128
) -> List [InputRequest ]:
128
129
# Load the dataset.
129
130
with open (dataset_path ) as f :
@@ -167,7 +168,7 @@ def sample_requests(
167
168
if prompt_len > 1024 or prompt_len + output_len > 2048 :
168
169
# Prune too long sequences.
169
170
continue
170
- reqeust = InputRequest (prompt , prompt_len , output , output_len )
171
+ reqeust = InputRequest (prompt , prompt_len , output , max_output_length )
171
172
filtered_dataset .append (reqeust )
172
173
173
174
# Sample the requests.
@@ -388,7 +389,7 @@ def main(args: argparse.Namespace):
388
389
if tokenizer == "test" or args .dataset == "test" :
389
390
input_requests = mock_requests (args .total_mock_requests ) # e.g. [("AB", 2, "AB", 3)]
390
391
else :
391
- input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer )
392
+ input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer , args . max_output_length )
392
393
393
394
benchmark_result , request_outputs = asyncio .run (
394
395
benchmark (
@@ -501,6 +502,14 @@ def main(args: argparse.Namespace):
501
502
default = 150 ,
502
503
help = "The maximum number of mock requests to send for benchmark testing." ,
503
504
)
505
+
506
+ parser .add_argument (
507
+ "--max-output-length" ,
508
+ type = int ,
509
+ default = 1024 ,
510
+ help = "The maximum output length for reference request." ,
511
+ )
512
+
504
513
parser .add_argument ("--seed" , type = int , default = 0 )
505
514
parser .add_argument (
506
515
"--disable-tqdm" ,
0 commit comments