diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index fc8850f90..e3bcb822e 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -435,7 +435,7 @@ def launch_nccl(parser, hosts, args, command): for rank in range(world_size): env = base_env.copy() env["MLX_RANK"] = str(rank) - + env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node) p = Popen(command, env=env) procs.append(p)