From 51505c2d5a60661e6d8bd1e6df475cb1d522c6dd Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 22 Aug 2025 09:39:36 -0700
Subject: [PATCH] check num gpus + ensure row contiguous for all reduce

---
 mlx/backend/cuda/distributed.cu | 31 ++++++++++++++++++-------------
 python/mlx/distributed_run.py   | 18 +++++++++++++++---
 2 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu
index 2cdf615f5..dba168a68 100644
--- a/mlx/backend/cuda/distributed.cu
+++ b/mlx/backend/cuda/distributed.cu
@@ -1,5 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/distributed/primitives.h"
@@ -7,25 +8,30 @@
 
 #include <nccl.h>
 
-namespace mlx::core {
-namespace distributed {
+namespace mlx::core::distributed {
 
 void AllReduce::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);
-  auto& input = inputs[0];
-  auto& output = outputs[0];
+
+  auto set_input_output = [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+    if (!in.flags().row_contiguous) {
+      copy_gpu(in, out, CopyType::General, s);
+      return {out, out};
+    } else if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+      return {in, out};
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+      return {in, out};
+    }
+  };
+
+  auto [input, output] = set_input_output(inputs[0], outputs[0]);
 
   auto& encoder = cu::get_command_encoder(stream());
-
-  if (input.is_donatable()) {
-    output.copy_shared_buffer(input);
-  } else {
-    output.set_data(allocator::malloc(output.nbytes()));
-  }
-
   encoder.set_input_array(input);
   encoder.set_output_array(output);
 
@@ -47,5 +53,4 @@ void AllReduce::eval_gpu(
           "Only all reduce sum, max, and min are supported.");
   }
 }
-} // namespace distributed
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core::distributed
diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py
index a588326de..31274d4a9 100644
--- a/python/mlx/distributed_run.py
+++ b/python/mlx/distributed_run.py
@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
     return ports
 
 
+def get_num_nvidia_gpus():
+    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
+    return len(result.stdout.strip().split("\n"))
+
+
 def extract_rings(hosts, index):
     def usable_port(i, j, used_ports):
         return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
@@ -421,7 +426,7 @@
 
     master_host = hosts[0].ips[0]
     if master_host != "127.0.0.1":
-        raise ValueError("The NCCL backend only supports localhost for now. ")
+        raise ValueError("The NCCL backend only supports localhost for now.")
 
     master_port = args.nccl_port
     world_size = len(hosts)
@@ -436,11 +441,18 @@
         }
     )
     procs = []
+    num_gpus = get_num_nvidia_gpus()
+    if num_gpus == 0:
+        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
+    if args.repeat_hosts > num_gpus:
+        raise RuntimeError("NCCL requires a separate GPU per process.")
+
     try:
        for rank in range(world_size):
            env = base_env.copy()
-           env["MLX_RANK"] = str(rank % args.repeat_hosts)
-           env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
+           mlx_rank = str(rank % args.repeat_hosts)
+           env["MLX_RANK"] = mlx_rank
+           env["CUDA_VISIBLE_DEVICES"] = mlx_rank
            p = Popen(command, env=env)
            procs.append(p)
 
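
Note (not part of the patch): the copy_gpu branch above exists because the NCCL all-reduce expects a dense, row-contiguous buffer, while an MLX array can be a strided view. A minimal sketch of the case it covers, assuming a working multi-process mx.distributed setup; mx.distributed.init, mx.distributed.all_sum, and mx.eval are the public MLX calls, and the transpose is just an easy way to produce a non-row-contiguous input:

    import mlx.core as mx

    group = mx.distributed.init()               # join the group set up by the launcher
    x = mx.arange(16).reshape(4, 4).T           # a transposed view is not row contiguous
    y = mx.distributed.all_sum(x, group=group)  # with this patch, such inputs are copied to
    mx.eval(y)                                  # a contiguous buffer before the NCCL call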
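
Also not part of the patch: on the launcher side, the new checks rely on nvidia-smi -L printing one line per device and on pinning each local rank to its own GPU through CUDA_VISIBLE_DEVICES. A rough sketch of the per-rank environment for a two-GPU localhost run, with illustrative values only; the real base_env built earlier in the launcher carries additional settings such as the world size and the NCCL host/port:

    # Illustrative subset of the environment each spawned rank receives
    rank0_env = {"MLX_RANK": "0", "CUDA_VISIBLE_DEVICES": "0"}
    rank1_env = {"MLX_RANK": "1", "CUDA_VISIBLE_DEVICES": "1"}
    # Pinning one device per rank means every process sees its GPU as device 0,
    # and oversubscription (more ranks than GPUs) is now rejected up front.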