diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index afd8b5130..3dd373eb1 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -20,6 +20,8 @@ from select import select from subprocess import PIPE, Popen, run from typing import Optional +import mlx.core as mx + @dataclass class Host: @@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command): try: for rank in range(world_size): env = base_env.copy() - env["MLX_RANK"] = str(rank) - env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node) + env["MLX_RANK"] = str(rank % args.repeat_hosts) + env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts) p = Popen(command, env=env) procs.append(p) @@ -708,7 +710,7 @@ def distributed_config(): parser.add_argument( "--backend", choices=["ring", "mpi", "nccl"], - default="ring", + default="nccl" if mx.cuda.is_available() else "ring", help="Which distributed backend to configure", ) parser.add_argument( @@ -780,7 +782,7 @@ def main(): parser.add_argument( "--backend", choices=["ring", "mpi", "nccl"], - default="ring", + default="nccl" if mx.cuda.is_available() else "ring", help="Which distributed backend to launch", ) parser.add_argument( diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh index e83fc4406..9f8cd9b0d 100644 --- a/python/scripts/repair_cuda.sh +++ b/python/scripts/repair_cuda.sh @@ -6,6 +6,7 @@ auditwheel repair dist/* \ --exclude libnvrtc* \ --exclude libcuda* \ --exclude libcudnn* \ + --exclude libnccl* \ -w wheel_tmp @@ -17,7 +18,7 @@ rm "${repaired_wheel}" mlx_so="mlx/lib/libmlx.so" rpath=$(patchelf --print-rpath "${mlx_so}") base="\$ORIGIN/../../nvidia" -rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib +rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" python ../python/scripts/repair_record.py ${mlx_so} diff --git a/setup.py b/setup.py index d15f36b2b..72646e9ad 100644 --- a/setup.py +++ b/setup.py @@ -297,6 +297,7 @@ if __name__ == "__main__": "nvidia-cublas-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cudnn-cu12==9.*", + "nvidia-nccl-cu12", ] else: name = "mlx-cpu"