From 44721ac62bc8737a10f3b30b1094138839c7751d Mon Sep 17 00:00:00 2001 From: hsmallbone Date: Wed, 8 Jan 2025 13:21:33 +0800 Subject: [PATCH] feat: Add no_ssh multinode launcher option for deepspeed --- src/accelerate/utils/constants.py | 2 +- src/accelerate/utils/launch.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py index a6d7d262678..af5a95da123 100644 --- a/src/accelerate/utils/constants.py +++ b/src/accelerate/utils/constants.py @@ -41,7 +41,7 @@ "2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image. ) FSDP_MODEL_NAME = "pytorch_model_fsdp" -DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"] +DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh"] TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"] ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0" XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0" diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index c6f3d60031d..7e413ee5fb4 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -321,8 +321,12 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0] if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: - cmd = ["deepspeed", "--no_local_rank"] - cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)]) + cmd = ["deepspeed"] + cmd.extend(["--hostfile", str(args.deepspeed_hostfile)]) + if args.deepspeed_multinode_launcher == "nossh": + cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"]) + else: + cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)]) if args.deepspeed_exclusion_filter is not None: cmd.extend( [