mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-24 06:16:38 +08:00
nccl dep + default for cuda (#2526)
This commit is contained in:
parent
9392fc3f88
commit
f93f87c802
@ -20,6 +20,8 @@ from select import select
|
|||||||
from subprocess import PIPE, Popen, run
|
from subprocess import PIPE, Popen, run
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Host:
|
class Host:
|
||||||
@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
try:
|
try:
|
||||||
for rank in range(world_size):
|
for rank in range(world_size):
|
||||||
env = base_env.copy()
|
env = base_env.copy()
|
||||||
env["MLX_RANK"] = str(rank)
|
env["MLX_RANK"] = str(rank % args.repeat_hosts)
|
||||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
|
||||||
p = Popen(command, env=env)
|
p = Popen(command, env=env)
|
||||||
procs.append(p)
|
procs.append(p)
|
||||||
|
|
||||||
@ -708,7 +710,7 @@ def distributed_config():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -780,7 +782,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -6,6 +6,7 @@ auditwheel repair dist/* \
|
|||||||
--exclude libnvrtc* \
|
--exclude libnvrtc* \
|
||||||
--exclude libcuda* \
|
--exclude libcuda* \
|
||||||
--exclude libcudnn* \
|
--exclude libcudnn* \
|
||||||
|
--exclude libnccl* \
|
||||||
-w wheel_tmp
|
-w wheel_tmp
|
||||||
|
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ rm "${repaired_wheel}"
|
|||||||
mlx_so="mlx/lib/libmlx.so"
|
mlx_so="mlx/lib/libmlx.so"
|
||||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||||
base="\$ORIGIN/../../nvidia"
|
base="\$ORIGIN/../../nvidia"
|
||||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
|
||||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||||
python ../python/scripts/repair_record.py ${mlx_so}
|
python ../python/scripts/repair_record.py ${mlx_so}
|
||||||
|
|
||||||
|
1
setup.py
1
setup.py
@ -297,6 +297,7 @@ if __name__ == "__main__":
|
|||||||
"nvidia-cublas-cu12==12.9.*",
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
"nvidia-cudnn-cu12==9.*",
|
"nvidia-cudnn-cu12==9.*",
|
||||||
|
"nvidia-nccl-cu12",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
|
Loading…
Reference in New Issue
Block a user