File tree Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -711,7 +711,10 @@ def create_pytorch_engine(
711
711
pt_model = None
712
712
713
713
if not sharding_config :
714
- sharding_config = os .path .join ("default_shardings" , model_name + ".yaml" )
714
+ sharding_file_name = "llama" if model_name .startswith ("llama" ) else "gemma"
715
+ sharding_config = os .path .join (
716
+ "default_shardings" , sharding_file_name + ".yaml"
717
+ )
715
718
716
719
env_data = JetEngineEnvironmentData (
717
720
tokenizer_path = tokenizer_path ,
Original file line number Diff line number Diff line change @@ -105,10 +105,6 @@ def main(argv: Sequence[str]):
105
105
devices = server_lib .get_devices ()
106
106
print (f"devices: { devices } " )
107
107
sharding_config_path = _SHARDING_CONFIG .value
108
- if not sharding_config_path :
109
- sharding_config_path = os .path .join (
110
- "default_shardings" , _MODEL_NAME .value + ".yaml"
111
- )
112
108
engine = jetstream_pt .create_pytorch_engine (
113
109
devices = devices ,
114
110
tokenizer_path = _TOKENIZER_PATH .value ,
You can’t perform that action at this time.
0 commit comments