Compare commits

..

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
bf9456f6cc Change the name to a fun pun 2025-11-20 17:48:23 -08:00
Angelos Katharopoulos
704f81c03d Add headers for gcc 2025-11-20 17:31:02 -08:00
Angelos Katharopoulos
df6b23156f Expose per-backend availability in C++ and python 2025-11-20 15:26:59 -08:00
10 changed files with 83 additions and 42 deletions

View File

@@ -11,4 +11,4 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

View File

@@ -5,7 +5,7 @@
#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ibv/ibv.h" #include "mlx/distributed/jaccl/jaccl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h" #include "mlx/distributed/ring/ring.h"
@@ -104,7 +104,26 @@ class EmptyGroup : public GroupImpl {
bool is_available() { bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available() || return mpi::is_available() || ring::is_available() || nccl::is_available() ||
ibv::is_available(); jaccl::is_available();
}
bool is_available(const std::string& bk) {
if (bk == "any") {
return is_available();
}
if (bk == "mpi") {
return mpi::is_available();
}
if (bk == "ring") {
return ring::is_available();
}
if (bk == "nccl") {
return nccl::is_available();
}
if (bk == "jaccl") {
return jaccl::is_available();
}
return false;
} }
int Group::rank() const { int Group::rank() const {
@@ -137,8 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict); group = ring::init(strict);
} else if (bk == "nccl") { } else if (bk == "nccl") {
group = nccl::init(strict); group = nccl::init(strict);
} else if (bk == "ibv") { } else if (bk == "jaccl") {
group = ibv::init(strict); group = jaccl::init(strict);
} else if (bk == "any") { } else if (bk == "any") {
if (mlx::core::cu::is_available()) { if (mlx::core::cu::is_available()) {
group = nccl::init(false); group = nccl::init(false);
@@ -153,8 +172,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
bk_ = "mpi"; bk_ = "mpi";
} }
if (group == nullptr) { if (group == nullptr) {
group = ibv::init(false); group = jaccl::init(false);
bk_ = "ibv"; bk_ = "jaccl";
} }
if (group == nullptr && strict) { if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend"); throw std::runtime_error("[distributed] Couldn't initialize any backend");
@@ -162,7 +181,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', " msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'ibv' and 'ring' but '" << bk << "' was provided."; << "'jaccl' and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@@ -16,6 +16,7 @@ class GroupImpl;
/* Check if a communication backend is available */ /* Check if a communication backend is available */
bool is_available(); bool is_available();
bool is_available(const std::string& bk);
/** /**
* A distributed::Group represents a group of independent mlx processes that * A distributed::Group represents a group of independent mlx processes that

View File

@@ -1,8 +1,8 @@
if(MLX_BUILD_CPU if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2) AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
target_link_libraries(mlx PRIVATE rdma) target_link_libraries(mlx PRIVATE rdma)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif() endif()

View File

@@ -13,7 +13,7 @@
#include "mlx/distributed/utils.h" #include "mlx/distributed/utils.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
constexpr const char* IBV_TAG = "[ibv]"; constexpr const char* IBV_TAG = "[jaccl]";
constexpr int NUM_BUFFERS = 2; constexpr int NUM_BUFFERS = 2;
constexpr int BUFFER_SIZE = 4096; constexpr int BUFFER_SIZE = 4096;
constexpr int MAX_SEND_WR = 32; constexpr int MAX_SEND_WR = 32;
@@ -103,7 +103,7 @@ class SharedBuffer {
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr}); auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
if (!inserted) { if (!inserted) {
throw std::runtime_error( throw std::runtime_error(
"[ibv] Buffer can be registered once per protection domain"); "[jaccl] Buffer can be registered once per protection domain");
} }
it->second = ibv_reg_mr( it->second = ibv_reg_mr(
@@ -113,7 +113,7 @@ class SharedBuffer {
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_REMOTE_WRITE); IBV_ACCESS_REMOTE_WRITE);
if (!it->second) { if (!it->second) {
throw std::runtime_error("[ibv] Register memory region failed"); throw std::runtime_error("[jaccl] Register memory region failed");
} }
} }
@@ -205,14 +205,14 @@ struct Connection {
void allocate_protection_domain() { void allocate_protection_domain() {
protection_domain = ibv_alloc_pd(ctx); protection_domain = ibv_alloc_pd(ctx);
if (protection_domain == nullptr) { if (protection_domain == nullptr) {
throw std::runtime_error("[ibv] Couldn't allocate protection domain"); throw std::runtime_error("[jaccl] Couldn't allocate protection domain");
} }
} }
void create_completion_queue(int num_entries) { void create_completion_queue(int num_entries) {
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0); completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
if (completion_queue == nullptr) { if (completion_queue == nullptr) {
throw std::runtime_error("[ibv] Couldn't create completion queue"); throw std::runtime_error("[jaccl] Couldn't create completion queue");
} }
} }
@@ -234,7 +234,7 @@ struct Connection {
queue_pair = ibv_create_qp(protection_domain, &init_attr); queue_pair = ibv_create_qp(protection_domain, &init_attr);
if (queue_pair == nullptr) { if (queue_pair == nullptr) {
throw std::runtime_error("[ibv] Couldn't create queue pair"); throw std::runtime_error("[jaccl] Couldn't create queue pair");
} }
} }
@@ -269,7 +269,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to INIT failed with errno " << status; msg << "[jaccl] Changing queue pair to INIT failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -299,7 +299,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTR failed with errno " << status; msg << "[jaccl] Changing queue pair to RTR failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -313,7 +313,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTS failed with errno " << status; msg << "[jaccl] Changing queue pair to RTS failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -333,7 +333,7 @@ struct Connection {
ibv_post_send(queue_pair, &work_request, &bad_work_request); ibv_post_send(queue_pair, &work_request, &bad_work_request);
status != 0) { status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Send failed with error code " << status; msg << "[jaccl] Send failed with error code " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -351,7 +351,7 @@ struct Connection {
ibv_post_recv(queue_pair, &work_request, &bad_work_request); ibv_post_recv(queue_pair, &work_request, &bad_work_request);
status != 0) { status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Recv failed with error code " << status; msg << "[jaccl] Recv failed with error code " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -491,7 +491,7 @@ class ConnectionManager {
side_channel_(rank_, size_, coordinator_addr) { side_channel_(rank_, size_, coordinator_addr) {
create_contexts(device_names); create_contexts(device_names);
if (connections_[rank_].ctx != nullptr) { if (connections_[rank_].ctx != nullptr) {
throw std::runtime_error("[ibv] Malformed device file"); throw std::runtime_error("[jaccl] Malformed device file");
} }
} }
@@ -653,7 +653,7 @@ class ConnectionManager {
auto ctx = ibv_open_device(devices[i]); auto ctx = ibv_open_device(devices[i]);
if (ctx == nullptr) { if (ctx == nullptr) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Could not open device " << name; msg << "[jaccl] Could not open device " << name;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
connections_.emplace_back(ctx); connections_.emplace_back(ctx);
@@ -688,7 +688,7 @@ std::vector<std::string> load_device_names(int rank, const char* dev_file) {
return device_names; return device_names;
} }
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
class IBVGroup : public GroupImpl { class IBVGroup : public GroupImpl {
public: public:
@@ -926,11 +926,11 @@ class IBVGroup : public GroupImpl {
} }
void sum_scatter(const array& input, array& output, Stream stream) override { void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] sum_scatter not supported."); throw std::runtime_error("[jaccl] sum_scatter not supported.");
} }
std::shared_ptr<GroupImpl> split(int color, int key = -1) override { std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ibv] Group split not supported."); throw std::runtime_error("[jaccl] Group split not supported.");
} }
private: private:
@@ -1079,6 +1079,8 @@ class IBVGroup : public GroupImpl {
bool is_available() { bool is_available() {
if (__builtin_available(macOS 26.2, *)) { if (__builtin_available(macOS 26.2, *)) {
return true; return true;
} else {
return false;
} }
} }
@@ -1088,10 +1090,10 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
const char* rank_str = std::getenv("MLX_RANK"); const char* rank_str = std::getenv("MLX_RANK");
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE"); const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
if (!dev_file || !coordinator || !rank_str) { if (!is_available() || !dev_file || !coordinator || !rank_str) {
if (strict) { if (strict) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), " msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) " << "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "") << "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "") << "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
@@ -1108,7 +1110,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
auto cm = ConnectionManager(rank, device_names, coordinator); auto cm = ConnectionManager(rank, device_names, coordinator);
if (cm.size() > MAX_PEERS) { if (cm.size() > MAX_PEERS) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS msg << "[jaccl] The maximum number of supported peers is " << MAX_PEERS
<< " but " << cm.size() << " was provided"; << " but " << cm.size() << " was provided";
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
@@ -1119,4 +1121,4 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
return std::make_shared<IBVGroup>(std::move(cm)); return std::make_shared<IBVGroup>(std::move(cm));
} }
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -2,11 +2,11 @@
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available(); bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false); std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/distributed/ibv/ibv.h" #include "mlx/distributed/jaccl/jaccl.h"
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@@ -12,9 +12,9 @@ bool is_available() {
std::shared_ptr<GroupImpl> init(bool strict /* = false */) { std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) { if (strict) {
throw std::runtime_error("Cannot initialize ibv distributed backend."); throw std::runtime_error("Cannot initialize jaccl distributed backend.");
} }
return nullptr; return nullptr;
} }
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -2,6 +2,7 @@
#include <netdb.h> #include <netdb.h>
#include <unistd.h> #include <unistd.h>
#include <cstring>
#include <sstream> #include <sstream>
#include <thread> #include <thread>
@@ -14,7 +15,7 @@ namespace mlx::core::distributed::detail {
*/ */
address_t parse_address(const std::string& ip, const std::string& port) { address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res; struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints)); std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = SOCK_STREAM;

View File

@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <sys/socket.h> #include <sys/socket.h>
#include <functional>
#include <string>
namespace mlx::core::distributed::detail { namespace mlx::core::distributed::detail {

View File

@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"is_available", "is_available",
&mx::distributed::is_available, [](const std::string& backend) {
return mx::distributed::is_available(backend);
},
"backend"_a = "any",
nb::sig("def is_available(backend: str = 'any') -> bool"),
R"pbdoc( R"pbdoc(
Check if a communication backend is available. Check if a communication backend is available.
Note, this function returns whether MLX has the capability of
instantiating that distributed backend not whether it is possible to
create a communication group. For that purpose one should use
``init(strict=True)``.
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as ``init()``. Default: ``any``.
Returns:
bool: Whether the distributed backend is available.
)pbdoc"); )pbdoc");
m.def( m.def(
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize. backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
available backends are tried and the first one that succeeds set to ``any`` all available backends are tried and the first one
becomes the global group which will be returned in subsequent that succeeds becomes the global group which will be returned in
calls. Default: ``any`` subsequent calls. Default: ``any``
Returns: Returns:
Group: The group representing all the launched processes. Group: The group representing all the launched processes.