From dadf8d9c93c80ed21ea33bae1ec074d8e646104b Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 7 Aug 2025 15:09:46 +0200 Subject: [PATCH] repeat host -> proc per node --- python/mlx/distributed_run.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index fa5e094bd..fc8850f90 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -418,7 +418,7 @@ def launch_mpi(parser, hosts, args, command): def launch_nccl(parser, hosts, args, command): master_host = hosts[0].ips[0] master_port = args.nccl_port - world_size = args.repeat_hosts * len(hosts) + world_size = args.nproc_per_node * len(hosts) base_env = os.environ.copy() base_env.update( @@ -814,6 +814,12 @@ def main(): default=12345, help="The port to use for the NCCL communication (only for nccl backend)", ) + parser.add_argument( + "--nproc-per-node", + type=positive_number, + default=1, + help="How many processes to run per node (only for nccl backend)", + ) args, rest = parser.parse_known_args() if rest[0] == "--":