mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 01:19:54 +08:00
check num gpus + ensure row contiguous for all reduce
This commit is contained in:
parent
1eb589cd77
commit
51505c2d5a
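The commit message names two fixes: the CUDA all-reduce now makes sure its input is row contiguous before handing it to NCCL (copying it if necessary), and the NCCL launcher now checks how many GPUs the machine actually has before spawning processes. As a quick, hedged illustration of what "row contiguous" means for a collective that takes a flat (pointer, count) buffer, here is a minimal Python sketch using NumPy's equivalent flag; it is not part of the patch:

# Illustration only (NumPy used for the layout check): row contiguous means
# C-ordered with no gaps between elements, which is the layout a flat
# all-reduce send buffer needs. A transposed view violates it and must be packed.
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(x.flags["C_CONTIGUOUS"])    # True  -> usable as a dense send buffer
print(x.T.flags["C_CONTIGUOUS"])  # False -> needs a copy before the collective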
@@ -1,5 +1,6 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
@@ -7,25 +8,29 @@

#include <cassert>

namespace mlx::core {
namespace distributed {
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& input = inputs[0];
  auto& output = outputs[0];

  auto set_input_output = [s = stream()](const array& in, array& out) -> std::pair<array, array> {
    if (!in.flags().row_contiguous) {
      copy_gpu(in, out, CopyType::General, s);
      return {out, out};
    } else if (in.is_donatable()) {
      out.copy_shared_buffer(in);
      return {in, out};
    } else {
      return {in, out};
    }
  };

  auto [input, output] = set_input_output(inputs[0], outputs[0]);

  auto& encoder = cu::get_command_encoder(stream());

  if (input.is_donatable()) {
    output.copy_shared_buffer(input);
  } else {
    output.set_data(allocator::malloc(output.nbytes()));
  }

  encoder.set_input_array(input);
  encoder.set_output_array(output);

@@ -47,5 +52,4 @@ void AllReduce::eval_gpu(
      "Only all reduce sum, max, and min are supported.");
}
}
} // namespace distributed
} // namespace mlx::core
} // namespace mlx::core::distributed
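For orientation, here is a hedged usage sketch of the path this file implements: with the new set_input_output helper, an input that is not row contiguous (for example a transposed view) is first copied into the output buffer, and the copy is what gets reduced. The snippet below assumes the mlx Python API (mx.distributed.init and mx.distributed.all_sum) on a CUDA/NCCL build; it is an illustration, not code from the commit:

# Sketch only: run an all-reduce over a non-row-contiguous input.
import mlx.core as mx

group = mx.distributed.init()                  # whatever backend the build provides
x = mx.transpose(mx.arange(6).reshape(2, 3))   # transposed view, not row contiguous
y = mx.distributed.all_sum(x)                  # with this patch, CUDA packs x first
print(y)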
@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
    return ports


def get_num_nvidia_gpus():
    result = run(['nvidia-smi', "-L"], capture_output=True, text=True, check=True)
    return len(result.stdout.strip().split("\n"))


def extract_rings(hosts, index):
    def usable_port(i, j, used_ports):
        return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
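A note on get_num_nvidia_gpus above: `nvidia-smi -L` prints one line per visible device, so counting the stripped lines of stdout gives the GPU count, and `check=True` makes the launcher fail early if nvidia-smi exits with an error (a missing binary raises FileNotFoundError even sooner). A small sketch of the same parse on canned output; the device names and UUIDs are placeholders:

# Not from the patch: typical `nvidia-smi -L` output and the parse applied to it.
sample_stdout = (
    "GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-xxxxxxxx)\n"
    "GPU 1: NVIDIA A100-SXM4-80GB (UUID: GPU-yyyyyyyy)\n"
)
num_gpus = len(sample_stdout.strip().split("\n"))
print(num_gpus)  # 2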
@@ -421,7 +426,7 @@ def launch_nccl(parser, hosts, args, command):
    master_host = hosts[0].ips[0]

    if master_host != "127.0.0.1":
        raise ValueError("The NCCL backend only supports localhost for now. ")
        raise ValueError("The NCCL backend only supports localhost for now.")
    master_port = args.nccl_port
    world_size = len(hosts)
@@ -436,11 +441,18 @@ def launch_nccl(parser, hosts, args, command):
        }
    )
    procs = []
    num_gpus = get_num_nvidia_gpus()
    if num_gpus == 0:
        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
    if args.repeat_hosts > num_gpus:
        raise RuntimeError("NCCL requires a separate GPU per process.")

    try:
        for rank in range(world_size):
            env = base_env.copy()
            env["MLX_RANK"] = str(rank % args.repeat_hosts)
            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
            mlx_rank = str(rank % args.repeat_hosts)
            env["MLX_RANK"] = mlx_rank
            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
            p = Popen(command, env=env)
            procs.append(p)
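To make the pinning concrete, here is a hedged illustration of what the loop above assigns on a single machine launched with repeat_hosts = 4 (so world_size is also 4): each rank gets its own CUDA_VISIBLE_DEVICES index, each process then sees exactly one GPU as its device 0, and that is why the new check requires at least as many physical GPUs as processes. This is an illustration, not part of the diff:

# Sketch only: the values the launch loop would set for four local ranks.
repeat_hosts = 4
for rank in range(repeat_hosts):
    mlx_rank = str(rank % repeat_hosts)
    print(f"rank {rank}: MLX_RANK={mlx_rank}, CUDA_VISIBLE_DEVICES={mlx_rank}")
# rank 0: MLX_RANK=0, CUDA_VISIBLE_DEVICES=0
# rank 1: MLX_RANK=1, CUDA_VISIBLE_DEVICES=1
# ... each process is paired with exactly one physical device.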