nccl dep + default for cuda (#2526)

This commit is contained in:
Awni Hannun 2025-08-21 17:57:49 -07:00 committed by GitHub
parent 9392fc3f88
commit f93f87c802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 5 deletions

View File

@ -20,6 +20,8 @@ from select import select
from subprocess import PIPE, Popen, run from subprocess import PIPE, Popen, run
from typing import Optional from typing import Optional
import mlx.core as mx
@dataclass @dataclass
class Host: class Host:
@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command):
try: try:
for rank in range(world_size): for rank in range(world_size):
env = base_env.copy() env = base_env.copy()
env["MLX_RANK"] = str(rank) env["MLX_RANK"] = str(rank % args.repeat_hosts)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node) env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
p = Popen(command, env=env) p = Popen(command, env=env)
procs.append(p) procs.append(p)
@ -708,7 +710,7 @@ def distributed_config():
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi", "nccl"],
default="ring", default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to configure", help="Which distributed backend to configure",
) )
parser.add_argument( parser.add_argument(
@ -780,7 +782,7 @@ def main():
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi", "nccl"],
default="ring", default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch", help="Which distributed backend to launch",
) )
parser.add_argument( parser.add_argument(

View File

@ -6,6 +6,7 @@ auditwheel repair dist/* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \ --exclude libcuda* \
--exclude libcudnn* \ --exclude libcudnn* \
--exclude libnccl* \
-w wheel_tmp -w wheel_tmp
@ -17,7 +18,7 @@ rm "${repaired_wheel}"
mlx_so="mlx/lib/libmlx.so" mlx_so="mlx/lib/libmlx.so"
rpath=$(patchelf --print-rpath "${mlx_so}") rpath=$(patchelf --print-rpath "${mlx_so}")
base="\$ORIGIN/../../nvidia" base="\$ORIGIN/../../nvidia"
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
python ../python/scripts/repair_record.py ${mlx_so} python ../python/scripts/repair_record.py ${mlx_so}

View File

@ -297,6 +297,7 @@ if __name__ == "__main__":
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*", "nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
else: else:
name = "mlx-cpu" name = "mlx-cpu"