Skip to content

Commit 6f8dd4a

Browse files
authored
Fix sharding config file name bug (#86)
* fix sharding config file name bug * format
1 parent caf0734 commit 6f8dd4a

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

jetstream_pt/engine.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,10 @@ def create_pytorch_engine(
711711
pt_model = None
712712

713713
if not sharding_config:
714-
sharding_config = os.path.join("default_shardings", model_name + ".yaml")
714+
sharding_file_name = "llama" if model_name.startswith("llama") else "gemma"
715+
sharding_config = os.path.join(
716+
"default_shardings", sharding_file_name + ".yaml"
717+
)
715718

716719
env_data = JetEngineEnvironmentData(
717720
tokenizer_path=tokenizer_path,

run_server.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,6 @@ def main(argv: Sequence[str]):
105105
devices = server_lib.get_devices()
106106
print(f"devices: {devices}")
107107
sharding_config_path = _SHARDING_CONFIG.value
108-
if not sharding_config_path:
109-
sharding_config_path = os.path.join(
110-
"default_shardings", _MODEL_NAME.value + ".yaml"
111-
)
112108
engine = jetstream_pt.create_pytorch_engine(
113109
devices=devices,
114110
tokenizer_path=_TOKENIZER_PATH.value,

0 commit comments

Comments (0)