check num gpus + ensure row contiguous for all reduce

Awni Hannun 2025-08-22 09:39:36 -07:00
parent 1eb589cd77
commit 51505c2d5a
2 changed files with 32 additions and 16 deletions

View File

@@ -1,5 +1,6 @@
 // Copyright © 2025 Apple Inc.
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/distributed/primitives.h"
@@ -7,25 +8,29 @@
 #include <cassert>
 
-namespace mlx::core {
-namespace distributed {
+namespace mlx::core::distributed {
 
 void AllReduce::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);
-  auto& input = inputs[0];
-  auto& output = outputs[0];
+  auto set_input_output = [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+    if (!in.flags().row_contiguous) {
+      copy_gpu(in, out, CopyType::General, s);
+      return {out, out};
+    } else if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+      return {in, out};
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+      return {in, out};
+    }
+  };
+  auto [input, output] = set_input_output(inputs[0], outputs[0]);
 
   auto& encoder = cu::get_command_encoder(stream());
-  if (input.is_donatable()) {
-    output.copy_shared_buffer(input);
-  } else {
-    output.set_data(allocator::malloc(output.nbytes()));
-  }
   encoder.set_input_array(input);
   encoder.set_output_array(output);
@@ -47,5 +52,4 @@ void AllReduce::eval_gpu(
           "Only all reduce sum, max, and min are supported.");
   }
 }
-} // namespace distributed
-} // namespace mlx::core
+} // namespace mlx::core::distributed
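
For context: the set_input_output lambda above routes three cases. An input that is not row contiguous is first copied into the output with copy_gpu and the reduction then runs on that contiguous copy; a donatable row-contiguous input shares its buffer with the output; otherwise the output gets its own allocation. A minimal Python sketch of the case the copy exists for, reducing a transposed and therefore non-row-contiguous array, could look like the following (this assumes a CUDA build of mlx started with the NCCL backend, e.g. via mlx.launch; shapes, values, and the print are illustrative only):

import mlx.core as mx

group = mx.distributed.init()  # NCCL group when started through the launcher

x = mx.arange(12.0).reshape(3, 4)
y = x.T  # transposed view: not row contiguous, so eval_gpu now copies it first

# Before this change the strided buffer was handed to the reduction as-is;
# with the lambda above it is copied into a row-contiguous output first.
total = mx.distributed.all_sum(y, group=group)
mx.eval(total)
print(group.rank(), total.shape)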

View File

@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
     return ports
 
 
+def get_num_nvidia_gpus():
+    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
+    return len(result.stdout.strip().split("\n"))
+
+
 def extract_rings(hosts, index):
     def usable_port(i, j, used_ports):
         return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
@@ -421,7 +426,7 @@ def launch_nccl(parser, hosts, args, command):
     master_host = hosts[0].ips[0]
     if master_host != "127.0.0.1":
-        raise ValueError("The NCCL backend only supports localhost for now. ")
+        raise ValueError("The NCCL backend only supports localhost for now.")
 
     master_port = args.nccl_port
     world_size = len(hosts)
@@ -436,11 +441,18 @@ def launch_nccl(parser, hosts, args, command):
         }
     )
 
     procs = []
+    num_gpus = get_num_nvidia_gpus()
+    if num_gpus == 0:
+        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
+    if args.repeat_hosts > num_gpus:
+        raise RuntimeError("NCCL requires a separate GPU per process.")
+
     try:
         for rank in range(world_size):
             env = base_env.copy()
-            env["MLX_RANK"] = str(rank % args.repeat_hosts)
-            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
+            mlx_rank = str(rank % args.repeat_hosts)
+            env["MLX_RANK"] = mlx_rank
+            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
             p = Popen(command, env=env)
             procs.append(p)
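
The launcher-side check counts devices by parsing nvidia-smi -L, which prints one line per visible GPU. A standalone sketch of the same guard is below; the helper name count_nvidia_gpus and the fallback to zero when nvidia-smi is missing or fails are illustrative assumptions, not part of this commit:

from subprocess import CalledProcessError, run

def count_nvidia_gpus():
    # "nvidia-smi -L" prints one line per GPU, e.g.
    # "GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-...)"
    try:
        result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
    except (FileNotFoundError, CalledProcessError):
        return 0
    out = result.stdout.strip()
    return len(out.split("\n")) if out else 0

num_gpus = count_nvidia_gpus()
if num_gpus == 0:
    raise RuntimeError("Cannot run NCCL backend with no GPUs.")

The args.repeat_hosts comparison in the diff then ensures each spawned rank can be pinned to its own device through CUDA_VISIBLE_DEVICES.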