format and use hf tokenizer api #65
@@ -16,9 +16,11 @@
 from torch import distributed as dist
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import generation, tokenizers
+from fms.utils import generation
 from fms.utils.generation import pad_input_ids
 
+from transformers import AutoTokenizer
+
 
 # This example script validates the LLaMA implementation by running inference on a couple of prompts.
 #
@@ -551,7 +553,7 @@ def select_int8_module(
 dprint(model)
 dprint("=" * 60 + "\n")
 
-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 model.eval()
 torch.set_grad_enabled(False)
 loading_model_time = time.time() - loading_model_time
@@ -570,15 +572,6 @@ def select_int8_module(
 add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id
 
 
-def ids_for_prompt(prompt):
-    tokens = tokenizer.tokenize(prompt)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    if add_special_tokens:
-        ids = [tokenizer.bos_token_id] + ids
-    ids = torch.tensor(ids, dtype=torch.long, device=device)
-    return ids
-
-
 def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     # we may want the prompt length to be fixed to some max length
     # this will ensure that prior to padding the input ids
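For context on this removal, here is a minimal sketch (not part of the diff) of how the deleted helper's conditional BOS handling maps onto the HF API; whether the script should pass `add_special_tokens` explicitly or rely on the tokenizer's own defaults is an open question of this PR, and the model id below is purely illustrative.

```python
from transformers import AutoTokenizer

# Illustrative model id only; any HF tokenizer path works here.
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")

# Old path: manual tokenize -> convert_tokens_to_ids -> conditional BOS prepend.
add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id
tokens = tokenizer.tokenize("hello world")
ids_manual = tokenizer.convert_tokens_to_ids(tokens)
if add_special_tokens:
    ids_manual = [tokenizer.bos_token_id] + ids_manual

# New path: encode() applies the tokenizer's own special-token rules; the flag can
# still be passed explicitly if the old conditional behavior must be preserved.
ids_hf = tokenizer.encode(
    "hello world", add_special_tokens=add_special_tokens, return_tensors="pt"
).squeeze(0)
# Note: tokenizers that also append EOS (or add no BOS) will not match the manual list exactly.
```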
@@ -626,7 +619,11 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     for i, prompt_file_path in enumerate(prompt_file_paths):
         if i == args.batch_size:
             break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+        prompts.append(
+            tokenizer.encode(
+                prompt_file_path.read_text(encoding="utf-8"), return_tensors="pt"
+            )
+        )
 
 else:
     if args.prompt_type == "chat":
@@ -656,10 +653,10 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
         dprint("prompt_type must be one of chat or code")
         exit()
 
-    prompt1 = ids_for_prompt(prompt1)
-    prompt2 = ids_for_prompt(prompt2)
-    prompt3 = ids_for_prompt(prompt3)
-    prompt4 = ids_for_prompt(prompt4)
+    prompt1 = tokenizer.encode(prompt1, return_tensors="pt").squeeze(0)
+    prompt2 = tokenizer.encode(prompt2, return_tensors="pt").squeeze(0)
+    prompt3 = tokenizer.encode(prompt3, return_tensors="pt").squeeze(0)
+    prompt4 = tokenizer.encode(prompt4, return_tensors="pt").squeeze(0)
     prompts = [prompt1, prompt2, prompt3, prompt4]
     prompts = prompts * ((args.batch_size // 4) + 1)
     prompts = prompts[: args.batch_size]

Reviewer: could we do a batch encode here?
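On the reviewer's batch-encode question, a possible sketch (an alternative, not what this PR does): the tokenizer's `__call__` can encode all four prompts in one pass with padding, but it returns a padded 2-D tensor plus an attention mask, while this script later pads via `pad_input_ids`, so the padded rows would need trimming back to 1-D tensors.

```python
# Hypothetical batch-encode variant; assumes `tokenizer` is the AutoTokenizer loaded above
# and prompt1..prompt4 are the prompt strings built earlier in the script.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding requires a pad token to be set

batch = tokenizer(
    [prompt1, prompt2, prompt3, prompt4],
    return_tensors="pt",
    padding=True,  # pad to the longest prompt in the batch
)
input_ids = batch["input_ids"]            # shape: (4, longest_prompt_len)
attention_mask = batch["attention_mask"]  # 1 for real tokens, 0 for padding

# To keep feeding pad_input_ids a list of unpadded 1-D tensors, strip the padding again,
# which is part of why per-prompt encode() calls may be the simpler choice here.
prompts = [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)]
```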
@@ -703,9 +700,7 @@ def print_result(result, result_idx: int):
     if not args.no_early_termination:
         result = generation.truncate_after_eos(result, tokenizer.eos_token_id)
 
-    output_str = tokenizer.convert_tokens_to_string(
-        tokenizer.convert_ids_to_tokens(result)
-    )
+    output_str = tokenizer.decode(result)
 
     if args.output_path != "":
         output_path = Path(args.output_path)
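A small note on the switch to `decode` (standard `transformers` behavior, not something this diff changes): by default `decode` keeps special tokens such as BOS/EOS in the output text, and it accepts `skip_special_tokens=True` if they should be stripped when printing results.

```python
# Sketch only: `result` is the 1-D tensor of generated token ids handled above.
output_str = tokenizer.decode(result)                                   # keeps special tokens
output_str_clean = tokenizer.decode(result, skip_special_tokens=True)   # drops BOS/EOS etc.
```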
@@ -11,7 +11,7 @@
 import torch._inductor.config
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import generation, tokenizers
+from fms.utils import generation
 from fms.utils.generation import pad_input_ids
 from torch import distributed as dist
 from aiu_fms_testing_utils.utils import warmup_model
@@ -27,6 +27,7 @@
 )
 from aiu_fms_testing_utils.utils import aiu_setup
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
+from transformers import AutoTokenizer
 
 # This example script validates models on AIU through comparisons to other devices.
 parser = argparse.ArgumentParser(
@@ -469,7 +470,7 @@
 dprint(validation_model)
 dprint("=" * 60 + "\n")
 
-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 model.eval()
 torch.set_grad_enabled(False)
 loading_model_time = time.time() - loading_model_time
@@ -490,15 +491,6 @@
 add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id
 
 
-def ids_for_prompt(prompt):
-    tokens = tokenizer.tokenize(prompt)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    if add_special_tokens:
-        ids = [tokenizer.bos_token_id] + ids
-    ids = torch.tensor(ids, dtype=torch.long, device="cpu")
-    return ids
-
-
 def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     # we may want the prompt length to be fixed to some max length
     # this will ensure that prior to padding the input ids
@@ -547,7 +539,11 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     for i, prompt_file_path in enumerate(prompt_file_paths):
         if i == args.batch_size:
             break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+        prompts.append(
+            tokenizer.encode(
+                prompt_file_path.read_text(encoding="utf-8"), return_tensors="pt"
+            )
+        )
 
 else:
     if args.prompt_type == "chat":
@@ -577,10 +573,10 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
         dprint("prompt_type must be one of chat or code")
         exit()
 
-    prompt1 = ids_for_prompt(prompt1)
-    prompt2 = ids_for_prompt(prompt2)
-    prompt3 = ids_for_prompt(prompt3)
-    prompt4 = ids_for_prompt(prompt4)
+    prompt1 = tokenizer.encode(prompt1, return_tensors="pt").squeeze(0)
+    prompt2 = tokenizer.encode(prompt2, return_tensors="pt").squeeze(0)
+    prompt3 = tokenizer.encode(prompt3, return_tensors="pt").squeeze(0)
+    prompt4 = tokenizer.encode(prompt4, return_tensors="pt").squeeze(0)
     prompts = [prompt1, prompt2, prompt3, prompt4]
     prompts = prompts * ((args.batch_size // 4) + 1)
     prompts = prompts[: args.batch_size]

Reviewer: could this be a batch encode?
@@ -622,9 +618,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     if not args.no_early_termination:
         result = generation.truncate_after_eos(result, tokenizer.eos_token_id)
 
-    output_str = tokenizer.convert_tokens_to_string(
-        tokenizer.convert_ids_to_tokens(result)
-    )
+    output_str = tokenizer.decode(result)
 
     if args.output_path != "":
         output_path = Path(args.output_path)
@@ -1,5 +1,5 @@
 from fms.models.hf.utils import AutoConfig
-from fms.utils import serialization, tokenizers
+from fms.utils import serialization
 import pytest
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
@@ -20,9 +20,10 @@
 from aiu_fms_testing_utils.utils import (
     warmup_model,
     sample_sharegpt_requests,
-    ids_for_prompt,
 )
 import json
+from transformers import AutoTokenizer
+
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
 
 import os
@@ -56,9 +57,6 @@
     GRANITE_3p3_8B_INSTRUCT: os.path.join(
         MICRO_MODELS_HOME, "granite-3.3-8b-layers-3-step-100000"
     ),
-    LLAMA_3p1_70B_INSTRUCT: os.path.join(
-        MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000"
-    ),
 }
 
 SHARE_GPT_DATASET_PATH = os.environ.get(

Reviewer: why are we removing this model?
@@ -295,7 +293,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     )
     prompt_list = []
    for prompt, _ in prompts_and_sizes:
-        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))
 
     input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, extra_kwargs
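On the `.squeeze(0)` in the new call, a short illustration (inferred from how `pad_input_ids` is fed in this file, so treat the expected input shape as an assumption): `encode(..., return_tensors="pt")` returns a batch-of-one tensor of shape `(1, seq_len)`, while the list passed to `pad_input_ids` holds per-prompt 1-D tensors.

```python
# Shape illustration; `tokenizer` is the HF tokenizer used in this test module.
ids_2d = tokenizer.encode("example prompt", return_tensors="pt")  # shape: (1, seq_len)
ids_1d = ids_2d.squeeze(0)                                        # shape: (seq_len,)

# pad_input_ids (fms.utils.generation) is then called with a list of such 1-D tensors:
# input_ids, extra_kwargs = pad_input_ids([ids_1d, ...], min_pad_length=seq_length)
```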
@@ -451,7 +449,7 @@ def test_common_shapes(
         **distributed_kwargs,
     }
 
-    tokenizer = tokenizers.get_tokenizer(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
 
     # prepare the AIU model
     model = persistent_model.get_or_create(
Reviewer: Would the above BaseTokenizer type hint need to be updated to a huggingface tokenizer?
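If the annotation does need updating, one possible direction is `transformers.PreTrainedTokenizerBase`, the shared base class of the fast and slow HF tokenizers; the function below is purely illustrative, not taken from the repository.

```python
from transformers import PreTrainedTokenizerBase

# Hypothetical helper showing only the annotation change being asked about;
# previously the parameter might have been typed as fms.utils.tokenizers.BaseTokenizer.
def encode_prompts(prompts: list[str], tokenizer: PreTrainedTokenizerBase) -> list:
    return [tokenizer.encode(p, return_tensors="pt").squeeze(0) for p in prompts]
```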