mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-23 22:06:40 +08:00
nccl dep + default for cuda (#2526)
This commit is contained in:
parent
9392fc3f88
commit
f93f87c802
@ -20,6 +20,8 @@ from select import select
|
||||
from subprocess import PIPE, Popen, run
|
||||
from typing import Optional
|
||||
|
||||
+ import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Host:
|
||||
@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command):
|
||||
try:
|
||||
for rank in range(world_size):
|
||||
env = base_env.copy()
|
||||
- env["MLX_RANK"] = str(rank)
|
||||
- env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||
+ env["MLX_RANK"] = str(rank % args.repeat_hosts)
|
||||
+ env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
|
||||
p = Popen(command, env=env)
|
||||
procs.append(p)
|
||||
|
||||
@ -708,7 +710,7 @@ def distributed_config():
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
- default="ring",
|
||||
+ default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -780,7 +782,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
- default="ring",
|
||||
+ default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -6,6 +6,7 @@ auditwheel repair dist/* \
|
||||
--exclude libnvrtc* \
|
||||
--exclude libcuda* \
|
||||
--exclude libcudnn* \
|
||||
+ --exclude libnccl* \
|
||||
-w wheel_tmp
|
||||
|
||||
|
||||
@ -17,7 +18,7 @@ rm "${repaired_wheel}"
|
||||
mlx_so="mlx/lib/libmlx.so"
|
||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||
base="\$ORIGIN/../../nvidia"
|
||||
- rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
||||
+ rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
|
||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||
python ../python/scripts/repair_record.py ${mlx_so}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user