nccl default for backend=any (#2528)

* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
This commit is contained in:
Awni Hannun
2025-08-22 12:24:27 -07:00
committed by GitHub
parent 5722c147de
commit 068a4612e9
5 changed files with 68 additions and 31 deletions

View File

@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
return ports
def get_num_nvidia_gpus():
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n"))
def extract_rings(hosts, index):
def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
@@ -421,14 +426,16 @@ def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now. ")
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": "INFO",
"NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
@@ -436,11 +443,18 @@ def launch_nccl(parser, hosts, args, command):
}
)
procs = []
num_gpus = get_num_nvidia_gpus()
if num_gpus == 0:
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
if args.repeat_hosts > num_gpus:
raise RuntimeError("NCCL requires a separate GPU per process.")
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank % args.repeat_hosts)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
mlx_rank = str(rank % args.repeat_hosts)
env["MLX_RANK"] = mlx_rank
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
p = Popen(command, env=env)
procs.append(p)
@@ -821,8 +835,6 @@ def main():
)
args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)
if args.print_python:
print(sys.executable)