
Commit 38506c1

format and use hf tokenizer api
Signed-off-by: kcirred <[email protected]>
1 parent 4f36d1b commit 38506c1

File tree: 10 files changed, 50 additions and 72 deletions
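
In short, the commit swaps the repo's hand-rolled `ids_for_prompt` helper (tokenize, convert to ids, optionally prepend BOS, build a tensor) and `fms.utils.tokenizers.get_tokenizer` for the Hugging Face tokenizer API. A minimal sketch of the old and new paths, assuming an arbitrary HF checkpoint (the model id below is only a placeholder):

    import torch
    from transformers import AutoTokenizer

    # placeholder checkpoint; any HF tokenizer behaves the same way here
    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
    prompt = "Provide a list of instructions for preparing chicken soup"

    # old path (removed in this commit): manual tokenize / convert / BOS handling
    tokens = tokenizer.tokenize(prompt)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    if tokenizer.bos_token_id != tokenizer.eos_token_id:
        ids = [tokenizer.bos_token_id] + ids
    old_ids = torch.tensor(ids, dtype=torch.long, device="cpu")

    # new path: let the tokenizer apply its own special-token recipe and return a tensor
    new_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)

Note the two paths are not guaranteed to produce identical ids for every model: the old helper only ever prepends BOS, while `encode` applies whatever special tokens the checkpoint's tokenizer config specifies.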

aiu_fms_testing_utils/testing/validation.py

Lines changed: 4 additions & 9 deletions
@@ -2,7 +2,6 @@
 from typing import List, Tuple, Callable, MutableMapping, Any, Optional

 import torch
-from aiu_fms_testing_utils.utils import ids_for_prompt
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os

@@ -206,8 +205,8 @@ def load_validation_information(
             # Text format will get tokenized
             validation_info.append(
                 {
-                    "tokens": ids_for_prompt(
-                        validation_file_path.read_text(encoding="utf-8"), tokenizer
+                    "tokens": tokenizer.encode(
+                        validation_file_path.read_text(encoding="utf-8"), return_tensors="pt"
                     ),
                     "logits": None,
                 }
@@ -378,12 +377,8 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
         aiu_token = aiu_tokens[sentence_index][token_index]
         validation_token = validation_tokens[sentence_index][token_index]

-        aiu_str = tokenizer.convert_tokens_to_string(
-            tokenizer.convert_ids_to_tokens(aiu_token)
-        )
-        validation_str = tokenizer.convert_tokens_to_string(
-            tokenizer.convert_ids_to_tokens(validation_token)
-        )
+        aiu_str = tokenizer.decode(aiu_token)
+        validation_str = tokenizer.decode(validation_token)
         print(
             f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
         )
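
The `print_failed_cases` change above collapses the two-step id-to-token-to-string round trip into a single call. A rough equivalence sketch, assuming `tokenizer` is an HF tokenizer and `token_id` is an int or 0-d tensor:

    # old: ids -> token strings -> detokenized text
    aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(token_id))

    # new: single call; decode also accepts lists/tensors of ids and can skip special tokens
    aiu_str = tokenizer.decode(token_id)

The decoded text can differ slightly between the two paths for some tokenizers, since `decode` can additionally apply `clean_up_tokenization_spaces`, depending on the tokenizer's configuration.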

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 2 additions & 10 deletions
@@ -67,14 +67,6 @@ def warmup_model(
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")


-def ids_for_prompt(prompt, tokenizer):
-    tokens = tokenizer.tokenize(prompt)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    if tokenizer.bos_token_id != tokenizer.eos_token_id:
-        ids = [tokenizer.bos_token_id] + ids
-    ids = torch.tensor(ids, dtype=torch.long, device="cpu")
-    return ids
-

 def __download_file(url, filename):
     try:
@@ -110,7 +102,7 @@ def __sample_requests(

         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
-        prompt_token_ids = ids_for_prompt(prompt, tokenizer)
+        prompt_token_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)

         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
@@ -217,7 +209,7 @@ def prepare_inputs(
     )
     prompt_list = []
     for prompt, _ in prompts_and_sizes:
-        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))

     input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, padding_kwargs
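
One detail worth calling out for the call sites above: with `return_tensors="pt"`, `encode` returns a batch-shaped 2-D tensor, so the trailing `.squeeze(0)` keeps each prompt's ids 1-D, which is the shape `pad_input_ids` received from the old `ids_for_prompt`. A small illustrative check (the prompt text is made up):

    ids = tokenizer.encode("hello world", return_tensors="pt")  # shape (1, seq_len)
    ids_1d = ids.squeeze(0)                                      # shape (seq_len,)
    assert ids_1d.dim() == 1 and ids_1d.numel() == ids.shape[1]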

scripts/generate_metrics.py

Lines changed: 4 additions & 4 deletions
@@ -15,10 +15,10 @@
     GoldenTokenHook,
     top_k_loss_calculator,
 )
-from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
+from aiu_fms_testing_utils.utils import sample_sharegpt_requests
 from fms.models import get_model
-from fms.utils import tokenizers
 from fms.utils.generation import pad_input_ids
+from transformers import AutoTokenizer

 parser = argparse.ArgumentParser(
     description="Script to determine a reasonable logits loss threshold when testing with aiu"
@@ -156,7 +156,7 @@
 if default_dtype is not None:
     torch.set_default_dtype(default_dtype)

-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

 torch.set_grad_enabled(False)

@@ -190,7 +190,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     )
     prompt_list = []
     for prompt, _ in prompts_and_sizes:
-        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))

     input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, padding_kwargs

scripts/inference.py

Lines changed: 14 additions & 19 deletions
@@ -16,9 +16,11 @@
 from torch import distributed as dist
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import generation, tokenizers
+from fms.utils import generation
 from fms.utils.generation import pad_input_ids

+from transformers import AutoTokenizer
+

 # This example script validates the LLaMA implementation by running inference on a couple of prompts.
 #
@@ -551,7 +553,7 @@ def select_int8_module(
 dprint(model)
 dprint("=" * 60 + "\n")

-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 model.eval()
 torch.set_grad_enabled(False)
 loading_model_time = time.time() - loading_model_time
@@ -570,15 +572,6 @@ def select_int8_module(
 add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id


-def ids_for_prompt(prompt):
-    tokens = tokenizer.tokenize(prompt)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    if add_special_tokens:
-        ids = [tokenizer.bos_token_id] + ids
-    ids = torch.tensor(ids, dtype=torch.long, device=device)
-    return ids
-
-
 def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     # we may want the prompt length to be fixed to some max length
     # this will ensure that prior to padding the input ids
@@ -626,7 +619,11 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     for i, prompt_file_path in enumerate(prompt_file_paths):
         if i == args.batch_size:
             break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+        prompts.append(
+            tokenizer.encode(
+                prompt_file_path.read_text(encoding="utf-8"), return_tensors="pt"
+            )
+        )

 else:
     if args.prompt_type == "chat":
@@ -656,10 +653,10 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
         dprint("prompt_type must be one of chat or code")
         exit()

-    prompt1 = ids_for_prompt(prompt1)
-    prompt2 = ids_for_prompt(prompt2)
-    prompt3 = ids_for_prompt(prompt3)
-    prompt4 = ids_for_prompt(prompt4)
+    prompt1 = tokenizer.encode(prompt1, return_tensors="pt").squeeze(0)
+    prompt2 = tokenizer.encode(prompt2, return_tensors="pt").squeeze(0)
+    prompt3 = tokenizer.encode(prompt3, return_tensors="pt").squeeze(0)
+    prompt4 = tokenizer.encode(prompt4, return_tensors="pt").squeeze(0)
     prompts = [prompt1, prompt2, prompt3, prompt4]
     prompts = prompts * ((args.batch_size // 4) + 1)
     prompts = prompts[: args.batch_size]
@@ -703,9 +700,7 @@ def print_result(result, result_idx: int):
     if not args.no_early_termination:
         result = generation.truncate_after_eos(result, tokenizer.eos_token_id)

-    output_str = tokenizer.convert_tokens_to_string(
-        tokenizer.convert_ids_to_tokens(result)
-    )
+    output_str = tokenizer.decode(result)

     if args.output_path != "":
         output_path = Path(args.output_path)

scripts/small-toy.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from torch_sendnn import torch_sendnn  # noqa


+
 # ==============================================================
 # Toy Encoder Model
 # ==============================================================

scripts/validation.py

Lines changed: 13 additions & 19 deletions
@@ -11,7 +11,7 @@
 import torch._inductor.config
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import generation, tokenizers
+from fms.utils import generation
 from fms.utils.generation import pad_input_ids
 from torch import distributed as dist
 from aiu_fms_testing_utils.utils import warmup_model
@@ -27,6 +27,7 @@
 )
 from aiu_fms_testing_utils.utils import aiu_setup
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
+from transformers import AutoTokenizer

 # This example script validates models on AIU through comparisons to other devices.
 parser = argparse.ArgumentParser(
@@ -469,7 +470,7 @@
 dprint(validation_model)
 dprint("=" * 60 + "\n")

-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 model.eval()
 torch.set_grad_enabled(False)
 loading_model_time = time.time() - loading_model_time
@@ -490,15 +491,6 @@
 add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id


-def ids_for_prompt(prompt):
-    tokens = tokenizer.tokenize(prompt)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    if add_special_tokens:
-        ids = [tokenizer.bos_token_id] + ids
-    ids = torch.tensor(ids, dtype=torch.long, device="cpu")
-    return ids
-
-
 def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     # we may want the prompt length to be fixed to some max length
     # this will ensure that prior to padding the input ids
@@ -547,7 +539,11 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     for i, prompt_file_path in enumerate(prompt_file_paths):
         if i == args.batch_size:
             break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+        prompts.append(
+            tokenizer.encode(
+                prompt_file_path.read_text(encoding="utf-8"), return_tensors="pt"
+            )
+        )

 else:
     if args.prompt_type == "chat":
@@ -577,10 +573,10 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
         dprint("prompt_type must be one of chat or code")
         exit()

-    prompt1 = ids_for_prompt(prompt1)
-    prompt2 = ids_for_prompt(prompt2)
-    prompt3 = ids_for_prompt(prompt3)
-    prompt4 = ids_for_prompt(prompt4)
+    prompt1 = tokenizer.encode(prompt1, return_tensors="pt").squeeze(0)
+    prompt2 = tokenizer.encode(prompt2, return_tensors="pt").squeeze(0)
+    prompt3 = tokenizer.encode(prompt3, return_tensors="pt").squeeze(0)
+    prompt4 = tokenizer.encode(prompt4, return_tensors="pt").squeeze(0)
     prompts = [prompt1, prompt2, prompt3, prompt4]
     prompts = prompts * ((args.batch_size // 4) + 1)
     prompts = prompts[: args.batch_size]
@@ -622,9 +618,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     if not args.no_early_termination:
         result = generation.truncate_after_eos(result, tokenizer.eos_token_id)

-    output_str = tokenizer.convert_tokens_to_string(
-        tokenizer.convert_ids_to_tokens(result)
-    )
+    output_str = tokenizer.decode(result)

     if args.output_path != "":
         output_path = Path(args.output_path)

tests/models/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 import os


+
 def pytest_sessionstart(session):
     """
     Called after the Session object has been created and

tests/models/test_decoders.py

Lines changed: 5 additions & 7 deletions
@@ -1,5 +1,5 @@
 from fms.models.hf.utils import AutoConfig
-from fms.utils import serialization, tokenizers
+from fms.utils import serialization
 import pytest
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
@@ -20,9 +20,10 @@
 from aiu_fms_testing_utils.utils import (
     warmup_model,
     sample_sharegpt_requests,
-    ids_for_prompt,
 )
 import json
+from transformers import AutoTokenizer
+
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup

 import os
@@ -56,9 +57,6 @@
     GRANITE_3p3_8B_INSTRUCT: os.path.join(
         MICRO_MODELS_HOME, "granite-3.3-8b-layers-3-step-100000"
     ),
-    LLAMA_3p1_70B_INSTRUCT: os.path.join(
-        MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000"
-    ),
 }

 SHARE_GPT_DATASET_PATH = os.environ.get(
@@ -295,7 +293,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     )
     prompt_list = []
     for prompt, _ in prompts_and_sizes:
-        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))

     input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, extra_kwargs
@@ -451,7 +449,7 @@ def test_common_shapes(
         **distributed_kwargs,
     }

-    tokenizer = tokenizers.get_tokenizer(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)

     # prepare the AIU model
     model = persistent_model.get_or_create(

tests/models/test_encoders.py

Lines changed: 4 additions & 4 deletions
@@ -2,16 +2,16 @@
     ModelSignatureParams,
     get_signature,
 )
-from fms.utils import tokenizers
 import pytest
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
 import itertools
 import torch
-from aiu_fms_testing_utils.utils import ids_for_prompt, sample_squad_v2_qa_requests
+from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
 import numpy as np
+from transformers import AutoTokenizer

 # Add models to test here
 ROBERTA_SQUAD_V2 = "deepset/roberta-base-squad2"
@@ -61,7 +61,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     )
     prompt_list = []
     for prompt, _ in prompts_and_sizes:
-        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))

     input_ids, padding_kwargs = pad_input_ids(
         prompt_list, min_pad_length=seq_length, is_causal_mask=False
@@ -111,7 +111,7 @@ def test_common_shapes(model_path, batch_size, seq_length):
         f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}"
     )

-    tokenizer = tokenizers.get_tokenizer(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)

     if os.path.exists(model_path):
         model_path_kwargs = {"model_path": model_path}

tests/models/test_scripts.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,7 @@ def execute_script(execute_cmd):
         raise Exception(error)


+
 def execute_inference(model_path, max_new_tokens, batch_size, seq_length):
     execute_cmd = [
         "python3",
@@ -71,6 +72,7 @@ def execute_inference(model_path, max_new_tokens, batch_size, seq_length):
     return execute_script(execute_cmd)


+
 common_asserts = [
     "### Response:\nProvide a list of instructions for preparing chicken soup",
     "### Response:\nExplain some popular greetings in Spanish.",
