Skip to content

Commit fe22a9f

Browse files
authoredNov 6, 2024··
Add model warmup flag into cli (#197)
add model warmup flag into cli
1 parent 02927c9 commit fe22a9f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed
 

‎jetstream_pt/cli.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
flags.DEFINE_bool(
3737
"internal_use_local_tokenizer", 0, "Use local tokenizer if set to True"
3838
)
39+
flags.DEFINE_bool("enable_model_warmup", False, "enable model warmup")
3940

4041

4142
def shard_weights(env, weights, weight_shardings):
@@ -111,6 +112,7 @@ def serve():
111112
config=server_config,
112113
devices=devices,
113114
metrics_server_config=metrics_server_config,
115+
enable_model_warmup=FLAGS.enable_model_warmup,
114116
)
115117
print("Started jetstream_server....")
116118
jetstream_server.wait_for_termination()

0 commit comments

Comments
 (0)
Please sign in to comment.