Skip to content

Commit 15c9a48

Browse files
add local tokenizer option for automated testing without hf token
1 parent 3afdc43 commit 15c9a48

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

jetstream_pt/cli.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
"",
3434
"if set, then save the result to the given file name",
3535
)
36-
36+
flags.DEFINE_bool(
37+
"internal_use_local_tokenizer",
38+
0,
39+
"Use local tokenizer if set to True")
3740

3841
def shard_weights(env, weights, weight_shardings):
3942
"""Shard weights according to weight_shardings"""
@@ -57,8 +60,11 @@ def create_engine(devices):
5760
FLAGS.max_input_length,
5861
FLAGS.max_output_length,
5962
)
60-
tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
6163
env = environment.JetEngineEnvironment(env_data)
64+
if FLAGS.internal_use_local_tokenizer:
65+
tokenizer = AutoTokenizer.from_pretrained(env_data.tokenizer_path)
66+
else:
67+
tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
6268
env.hf_tokenizer = tokenizer
6369
model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
6470
# NOTE: this is assigned later because, the model should be constructed

0 commit comments

Comments
 (0)