diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index e0074f80762..1cf2dbed68b 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -481,6 +481,12 @@ def launch_command_parser(subparsers=None): type=str, help="DeepSpeed hostfile for configuring multi-node compute resources.", ) + deepspeed_args.add_argument( + "--deepspeed_ssh_port", + default=None, + type=str, + help="SSH port to use for remote connections with DeepSpeed.", + ) deepspeed_args.add_argument( "--deepspeed_exclusion_filter", default=None, diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index c6f3d60031d..4c4b5e573aa 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -339,6 +339,8 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict ) else: cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)]) + if args.deepspeed_ssh_port is not None: + cmd.extend(["--ssh_port", str(args.deepspeed_ssh_port)]) if main_process_ip: cmd.extend(["--master_addr", str(main_process_ip)]) cmd.extend(["--master_port", str(main_process_port)])