Skip to content

Commit

Permalink
feat: Add no_ssh multinode launcher option for deepspeed
Browse files Browse the repository at this point in the history
  • Loading branch information
hsmallbone committed Jan 8, 2025
1 parent d6d3e03 commit 44721ac
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/accelerate/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
Expand Down
8 changes: 6 additions & 2 deletions src/accelerate/utils/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,12 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]

if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
cmd = ["deepspeed", "--no_local_rank"]
cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)])
cmd = ["deepspeed"]
cmd.extend(["--hostfile", str(args.deepspeed_hostfile)])
if args.deepspeed_multinode_launcher == "nossh":
cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"])
else:
cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)])
if args.deepspeed_exclusion_filter is not None:
cmd.extend(
[
Expand Down

0 comments on commit 44721ac

Please sign in to comment.