nccl default for backend=any (#2528)

* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
Awni Hannun 2025-08-22 12:24:27 -07:00 committed by GitHub
parent 5722c147de
commit 068a4612e9
5 changed files with 68 additions and 31 deletions
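
With this change, init(backend="any") prefers NCCL on CUDA builds before falling back to ring and then MPI. A minimal usage sketch of the user-visible effect (standard mlx.core.distributed API):

import mlx.core as mx

# backend="any" is the default; on a CUDA build with NCCL configured this
# now resolves to the nccl backend instead of ring.
group = mx.distributed.init()
print(group.rank(), group.size())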

View File

@@ -405,6 +405,7 @@ jobs:
       sudo dpkg -i cuda-keyring_1.1-1_all.deb
       sudo apt-get update
       sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+      sudo apt-get install libnccl2 libnccl-dev
       sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       sudo apt-get install zip
       pip install auditwheel
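
A quick sanity check that the NCCL runtime installed above is visible to the dynamic loader (a sketch, not part of the commit; "nccl" is the loader name that libnccl2 provides):

import ctypes.util

# find_library returns None when libnccl.so is not on the loader path.
if ctypes.util.find_library("nccl") is None:
    raise SystemExit("libnccl not found; install libnccl2 and libnccl-dev")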

View File

@@ -2,30 +2,35 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/primitives.h"

 #include <cassert>

-namespace mlx::core {
-namespace distributed {
+namespace mlx::core::distributed {

 void AllReduce::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);

-  auto& input = inputs[0];
-  auto& output = outputs[0];
+  auto set_input_output =
+      [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+    if (!in.flags().row_contiguous) {
+      copy_gpu(in, out, CopyType::General, s);
+      return {out, out};
+    } else if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+      return {in, out};
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+      return {in, out};
+    }
+  };
+
+  auto [input, output] = set_input_output(inputs[0], outputs[0]);

   auto& encoder = cu::get_command_encoder(stream());
-  if (input.is_donatable()) {
-    output.copy_shared_buffer(input);
-  } else {
-    output.set_data(allocator::malloc(output.nbytes()));
-  }
   encoder.set_input_array(input);
   encoder.set_output_array(output);

@@ -47,5 +52,4 @@ void AllReduce::eval_gpu(
       "Only all reduce sum, max, and min are supported.");
   }
 }
-} // namespace distributed
-} // namespace mlx::core
+} // namespace mlx::core::distributed
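
The new set_input_output lambda guarantees NCCL always sees a dense, row-contiguous buffer: non-contiguous inputs are first copied into the (row-contiguous) output, donatable inputs are reduced in place, and everything else gets a fresh allocation. A sketch of the case this handles, assuming a CUDA build with an NCCL group configured:

import mlx.core as mx

group = mx.distributed.init(backend="nccl")
x = mx.arange(16, dtype=mx.float32).reshape(4, 4).T  # transposed: not row contiguous
y = mx.distributed.all_sum(x, group=group)  # the primitive now copies first
mx.eval(y)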

View File

@@ -2,6 +2,7 @@

 #include <unordered_map>

+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
@@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   }

   // Create the requested communication group
-  std::shared_ptr<detail::GroupImpl> group;
+  std::shared_ptr<detail::GroupImpl> group{nullptr};
   std::string bk_ = bk;
   if (bk == "mpi") {
     group = mpi::init(strict);
@@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   } else if (bk == "nccl") {
     group = nccl::init(strict);
   } else if (bk == "any") {
+    if (mlx::core::cu::is_available()) {
+      group = nccl::init(false);
+      bk_ = "nccl";
+    }
+    if (group == nullptr) {
       group = ring::init(false);
       bk_ = "ring";
+    }
     if (group == nullptr) {
       group = mpi::init(false);
       bk_ = "mpi";
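
For readers skimming the C++: the backend="any" resolution order is now NCCL (only when CUDA is available), then ring, then MPI. The same selection logic mirrored in Python pseudocode (names are illustrative, not the real API):

def pick_backend(cuda_available, nccl_init, ring_init, mpi_init):
    group, name = None, None
    if cuda_available:
        group, name = nccl_init(strict=False), "nccl"
    if group is None:
        group, name = ring_init(strict=False), "ring"
    if group is None:
        group, name = mpi_init(strict=False), "mpi"
    return group, name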

View File

@@ -204,14 +204,18 @@ inline void bootstrap_unique_id(
   int attempt = 0;
   bool connected = false;
+  bool do_log = std::getenv("NCCL_DEBUG") && std::string(std::getenv("NCCL_DEBUG")) == "INFO";

   for (attempt = 0; attempt < max_retries; ++attempt) {
     if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
         0) {
       connected = true;
-      std::cout << "[Rank " << rank << "] Connected successfully on attempt "
-                << attempt + 1 << std::endl;
+      if (do_log) {
+        std::cout << "[Rank " << rank
+                  << "] Connected successfully on attempt " << attempt + 1
+                  << std::endl;
+      }
       break;
     }
     if (errno != ECONNREFUSED) {
       break;
     }
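
The bootstrap connect loop retries only while the coordinator's socket is not yet listening (ECONNREFUSED) and gives up on any other error. The same pattern sketched in Python (a hypothetical helper, not part of the commit):

import errno
import socket
import time

def connect_with_retry(host, port, max_retries=30, delay=0.1):
    for attempt in range(max_retries):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((host, port))
            return sock  # connected
        except OSError as e:
            sock.close()
            if e.errno != errno.ECONNREFUSED:
                raise  # unexpected failure: do not keep retrying
            time.sleep(delay)
    raise TimeoutError(f"could not reach {host}:{port}")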
@@ -331,24 +335,33 @@ bool is_available() {
 }

 namespace detail {

-static std::string get_env_var_or_throw(const char* env_var_name) {
+std::string get_env_var_or_throw(const char* env_var_name, bool strict) {
   const char* value = std::getenv(env_var_name);
-  if (value == nullptr) {
+  if (value == nullptr && strict) {
     std::ostringstream msg;
     msg << "[nccl] Required environment variable '" << env_var_name
         << "' is not set. "
         << "Please set it before initializing the distributed backend.";
     throw std::runtime_error(msg.str());
   }
+  if (value == nullptr) {
+    return "";
+  }
   return std::string(value);
 }

 } // namespace detail

 std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
-  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
-  std::string port = detail::get_env_var_or_throw("NCCL_PORT");
-  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
-  std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
+  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict);
+  std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict);
+  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict);
+  std::string n_nodes_str =
+      detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict);
+
+  if (!strict &&
+      (host.empty() || port.empty() || rank_str.empty() ||
+       n_nodes_str.empty())) {
+    return nullptr;
+  }
+
   int rank = std::stoi(rank_str);
   int n_nodes = std::stoi(n_nodes_str);
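
The non-strict path lets init(strict=false) report "NCCL not configured" by returning nullptr instead of throwing, which is what makes the backend="any" fallback above possible. The behavior, sketched in Python with hypothetical names:

import os

def get_env(name, strict):
    value = os.environ.get(name)
    if value is None and strict:
        raise RuntimeError(f"[nccl] Required environment variable '{name}' is not set.")
    return value or ""

def nccl_config(strict=False):
    keys = ("NCCL_HOST_IP", "NCCL_PORT", "MLX_RANK", "MLX_WORLD_SIZE")
    values = {k: get_env(k, strict) for k in keys}
    if not strict and any(v == "" for v in values.values()):
        return None  # not configured; caller falls back to ring/mpi
    return values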

View File

@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
     return ports


+def get_num_nvidia_gpus():
+    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
+    return len(result.stdout.strip().split("\n"))
+
+
 def extract_rings(hosts, index):
     def usable_port(i, j, used_ports):
         return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
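
get_num_nvidia_gpus assumes nvidia-smi exists and succeeds (check=True raises otherwise). A slightly more defensive variant, shown only as a sketch and not part of the commit, treats a missing or failing binary as zero GPUs:

from subprocess import CalledProcessError, run

def count_nvidia_gpus():
    try:
        result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
    except (FileNotFoundError, CalledProcessError):
        return 0
    out = result.stdout.strip()
    return len(out.splitlines()) if out else 0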
@@ -428,7 +433,9 @@ def launch_nccl(parser, hosts, args, command):
     base_env = os.environ.copy()
     base_env.update(
         {
-            "NCCL_DEBUG": "INFO",
+            "NCCL_DEBUG": base_env.get(
+                "NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
+            ),
             "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
             "NCCL_HOST_IP": master_host,
             "NCCL_PORT": str(master_port),
@@ -436,11 +443,18 @@ def launch_nccl(parser, hosts, args, command):
         }
     )

     procs = []
+    num_gpus = get_num_nvidia_gpus()
+    if num_gpus == 0:
+        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
+    if args.repeat_hosts > num_gpus:
+        raise RuntimeError("NCCL requires a separate GPU per process.")
+
     try:
         for rank in range(world_size):
             env = base_env.copy()
-            env["MLX_RANK"] = str(rank % args.repeat_hosts)
-            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
+            mlx_rank = str(rank % args.repeat_hosts)
+            env["MLX_RANK"] = mlx_rank
+            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
             p = Popen(command, env=env)
             procs.append(p)
@@ -821,8 +835,6 @@ def main():
     )

     args, rest = parser.parse_known_args()
-    if rest[0] == "--":
-        rest.pop(0)

     if args.print_python:
         print(sys.executable)