Add pre-generated prompts option for benchmark (#1091)
During benchmarking, we wanted to use pre-generated prompts that had been prepared ahead of time for better benchmark results, which can be handy. In our tests, we wanted to focus only on token generation and sampling on SLMs (small language models).
omer-demir authored Jan 24, 2025
1 parent 4db5f2a commit 0636ce3
Showing 2 changed files with 32 additions and 0 deletions.
21 changes: 21 additions & 0 deletions benchmark/python/benchmark_e2e.py
@@ -82,6 +82,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
        generator.generate_next_token()
    return tokenizer.decode(generator.get_sequence(0))

# Use prompt length to get pre-defined prompt
def get_prompt_by_length(prompt_length):
    json_path = "prompts.json"
    with open(json_path) as prompts_file:
        content = prompts_file.read()
        data = json.loads(content)
    return data[f"{prompt_length}"]

def get_target_pip_package_version(target_pip_package_name_list):
# get package name and version
import pkg_resources
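
For reference, a minimal usage sketch of the helper above. It reads "prompts.json" from the current working directory (presumably the second changed file in this commit); the keys and prompt texts shown here are illustrative only, not the actual file contents:

    # Hypothetical prompts.json layout -- keys are prompt lengths (as strings), values are prompt text:
    # {
    #     "16": "Summarize the following paragraph ...",
    #     "256": "You are a helpful assistant. Given the transcript below ..."
    # }
    prompt = get_prompt_by_length(256)   # returns the entry stored under the "256" key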
@@ -231,6 +239,18 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
        # use random tokens instead of generating a prompt using the model and then tokenizing it
        tokens = np.random.randint(100, size=(batch_size, prompt_length))
        prompt = [tokenizer.decode(tokens[0])] * batch_size
    elif args.use_prompt_set:
        prompt = [get_prompt_by_length(prompt_length)] * batch_size
        tokens = tokenizer.encode_batch(prompt)

        if tokens.shape[1] > max_length:
            # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
            tokens = tokens[:, :max_length]
        elif tokens.shape[1] < max_length:
            # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
            # by repeating the first column on the left
            tokens_first_col = tokens[:, 0:1]
            for _ in range(max_length - tokens.shape[1]):
                tokens = np.hstack((tokens_first_col, tokens))
    else:
        prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
        tokens = tokenizer.encode_batch(prompt)
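
To illustrate the truncate/left-pad step in the elif branch above, here is a standalone sketch assuming encode_batch yields a 2-D (batch_size, tokenized_length) integer array; fit_to_length and the toy values are illustrative, not part of the commit:

    import numpy as np

    def fit_to_length(tokens, max_length):
        # Truncate or left-pad a (batch_size, tokenized_length) array to (batch_size, max_length)
        if tokens.shape[1] > max_length:
            return tokens[:, :max_length]
        while tokens.shape[1] < max_length:
            # repeat the first column on the left until the requested length is reached
            tokens = np.hstack((tokens[:, 0:1], tokens))
        return tokens

    example = np.array([[101, 7592, 2088], [101, 2023, 2003]])  # toy (2, 3) token batch
    print(fit_to_length(example, 5).shape)                      # -> (2, 5)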
@@ -416,6 +436,7 @@ def str2strlist(value):
parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
args = parser.parse_args()

# check max_lengths
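
A possible invocation using only the flags visible in this diff; the model arguments the script already requires are elided, and the model name and precision values are illustrative:

    python benchmark/python/benchmark_e2e.py --use_prompt_set -mn my_model -pr fp16 ...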