nccl dep + default for cuda (#2526)

This commit is contained in:
Awni Hannun
2025-08-21 17:57:49 -07:00
committed by GitHub
parent 9392fc3f88
commit f93f87c802
3 changed files with 9 additions and 5 deletions

View File

@@ -20,6 +20,8 @@ from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
import mlx.core as mx
@dataclass
class Host:
@@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command):
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
env["MLX_RANK"] = str(rank % args.repeat_hosts)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
p = Popen(command, env=env)
procs.append(p)
@@ -708,7 +710,7 @@ def distributed_config():
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="ring",
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to configure",
)
parser.add_argument(
@@ -780,7 +782,7 @@ def main():
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="ring",
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(