Change the name to a fun pun

This commit is contained in:
Angelos Katharopoulos
2025-11-20 17:48:23 -08:00
parent 47af2c8cb0
commit 8fab4f0929
7 changed files with 38 additions and 38 deletions

View File

@@ -11,4 +11,4 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

View File

@@ -5,7 +5,7 @@
#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ibv/ibv.h" #include "mlx/distributed/jaccl/jaccl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h" #include "mlx/distributed/ring/ring.h"
@@ -104,7 +104,7 @@ class EmptyGroup : public GroupImpl {
bool is_available() { bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available() || return mpi::is_available() || ring::is_available() || nccl::is_available() ||
ibv::is_available(); jaccl::is_available();
} }
bool is_available(const std::string& bk) { bool is_available(const std::string& bk) {
@@ -120,8 +120,8 @@ bool is_available(const std::string& bk) {
if (bk == "nccl") { if (bk == "nccl") {
return nccl::is_available(); return nccl::is_available();
} }
if (bk == "ibv") { if (bk == "jaccl") {
return ibv::is_available(); return jaccl::is_available();
} }
return false; return false;
} }
@@ -156,8 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict); group = ring::init(strict);
} else if (bk == "nccl") { } else if (bk == "nccl") {
group = nccl::init(strict); group = nccl::init(strict);
} else if (bk == "ibv") { } else if (bk == "jaccl") {
group = ibv::init(strict); group = jaccl::init(strict);
} else if (bk == "any") { } else if (bk == "any") {
if (mlx::core::cu::is_available()) { if (mlx::core::cu::is_available()) {
group = nccl::init(false); group = nccl::init(false);
@@ -172,8 +172,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
bk_ = "mpi"; bk_ = "mpi";
} }
if (group == nullptr) { if (group == nullptr) {
group = ibv::init(false); group = jaccl::init(false);
bk_ = "ibv"; bk_ = "jaccl";
} }
if (group == nullptr && strict) { if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend"); throw std::runtime_error("[distributed] Couldn't initialize any backend");
@@ -181,7 +181,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', " msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'ibv' and 'ring' but '" << bk << "' was provided."; << "'jaccl' and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@@ -1,8 +1,8 @@
if(MLX_BUILD_CPU if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2) AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
target_link_libraries(mlx PRIVATE rdma) target_link_libraries(mlx PRIVATE rdma)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif() endif()

View File

@@ -13,7 +13,7 @@
#include "mlx/distributed/utils.h" #include "mlx/distributed/utils.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
constexpr const char* IBV_TAG = "[ibv]"; constexpr const char* IBV_TAG = "[jaccl]";
constexpr int NUM_BUFFERS = 2; constexpr int NUM_BUFFERS = 2;
constexpr int BUFFER_SIZE = 4096; constexpr int BUFFER_SIZE = 4096;
constexpr int MAX_SEND_WR = 32; constexpr int MAX_SEND_WR = 32;
@@ -103,7 +103,7 @@ class SharedBuffer {
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr}); auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
if (!inserted) { if (!inserted) {
throw std::runtime_error( throw std::runtime_error(
"[ibv] Buffer can be registered once per protection domain"); "[jaccl] Buffer can be registered once per protection domain");
} }
it->second = ibv_reg_mr( it->second = ibv_reg_mr(
@@ -113,7 +113,7 @@ class SharedBuffer {
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_REMOTE_WRITE); IBV_ACCESS_REMOTE_WRITE);
if (!it->second) { if (!it->second) {
throw std::runtime_error("[ibv] Register memory region failed"); throw std::runtime_error("[jaccl] Register memory region failed");
} }
} }
@@ -205,14 +205,14 @@ struct Connection {
void allocate_protection_domain() { void allocate_protection_domain() {
protection_domain = ibv_alloc_pd(ctx); protection_domain = ibv_alloc_pd(ctx);
if (protection_domain == nullptr) { if (protection_domain == nullptr) {
throw std::runtime_error("[ibv] Couldn't allocate protection domain"); throw std::runtime_error("[jaccl] Couldn't allocate protection domain");
} }
} }
void create_completion_queue(int num_entries) { void create_completion_queue(int num_entries) {
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0); completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
if (completion_queue == nullptr) { if (completion_queue == nullptr) {
throw std::runtime_error("[ibv] Couldn't create completion queue"); throw std::runtime_error("[jaccl] Couldn't create completion queue");
} }
} }
@@ -234,7 +234,7 @@ struct Connection {
queue_pair = ibv_create_qp(protection_domain, &init_attr); queue_pair = ibv_create_qp(protection_domain, &init_attr);
if (queue_pair == nullptr) { if (queue_pair == nullptr) {
throw std::runtime_error("[ibv] Couldn't create queue pair"); throw std::runtime_error("[jaccl] Couldn't create queue pair");
} }
} }
@@ -269,7 +269,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to INIT failed with errno " << status; msg << "[jaccl] Changing queue pair to INIT failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -299,7 +299,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTR failed with errno " << status; msg << "[jaccl] Changing queue pair to RTR failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -313,7 +313,7 @@ struct Connection {
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTS failed with errno " << status; msg << "[jaccl] Changing queue pair to RTS failed with errno " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -333,7 +333,7 @@ struct Connection {
ibv_post_send(queue_pair, &work_request, &bad_work_request); ibv_post_send(queue_pair, &work_request, &bad_work_request);
status != 0) { status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Send failed with error code " << status; msg << "[jaccl] Send failed with error code " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -351,7 +351,7 @@ struct Connection {
ibv_post_recv(queue_pair, &work_request, &bad_work_request); ibv_post_recv(queue_pair, &work_request, &bad_work_request);
status != 0) { status != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Recv failed with error code " << status; msg << "[jaccl] Recv failed with error code " << status;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@@ -491,7 +491,7 @@ class ConnectionManager {
side_channel_(rank_, size_, coordinator_addr) { side_channel_(rank_, size_, coordinator_addr) {
create_contexts(device_names); create_contexts(device_names);
if (connections_[rank_].ctx != nullptr) { if (connections_[rank_].ctx != nullptr) {
throw std::runtime_error("[ibv] Malformed device file"); throw std::runtime_error("[jaccl] Malformed device file");
} }
} }
@@ -653,7 +653,7 @@ class ConnectionManager {
auto ctx = ibv_open_device(devices[i]); auto ctx = ibv_open_device(devices[i]);
if (ctx == nullptr) { if (ctx == nullptr) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] Could not open device " << name; msg << "[jaccl] Could not open device " << name;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
connections_.emplace_back(ctx); connections_.emplace_back(ctx);
@@ -688,7 +688,7 @@ std::vector<std::string> load_device_names(int rank, const char* dev_file) {
return device_names; return device_names;
} }
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
class IBVGroup : public GroupImpl { class IBVGroup : public GroupImpl {
public: public:
@@ -926,11 +926,11 @@ class IBVGroup : public GroupImpl {
} }
void sum_scatter(const array& input, array& output, Stream stream) override { void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] sum_scatter not supported."); throw std::runtime_error("[jaccl] sum_scatter not supported.");
} }
std::shared_ptr<GroupImpl> split(int color, int key = -1) override { std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ibv] Group split not supported."); throw std::runtime_error("[jaccl] Group split not supported.");
} }
private: private:
@@ -1093,7 +1093,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (!is_available() || !dev_file || !coordinator || !rank_str) { if (!is_available() || !dev_file || !coordinator || !rank_str) {
if (strict) { if (strict) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), " msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) " << "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "") << "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "") << "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
@@ -1110,7 +1110,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
auto cm = ConnectionManager(rank, device_names, coordinator); auto cm = ConnectionManager(rank, device_names, coordinator);
if (cm.size() > MAX_PEERS) { if (cm.size() > MAX_PEERS) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS msg << "[jaccl] The maximum number of supported peers is " << MAX_PEERS
<< " but " << cm.size() << " was provided"; << " but " << cm.size() << " was provided";
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
@@ -1121,4 +1121,4 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
return std::make_shared<IBVGroup>(std::move(cm)); return std::make_shared<IBVGroup>(std::move(cm));
} }
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -2,11 +2,11 @@
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available(); bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false); std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/distributed/ibv/ibv.h" #include "mlx/distributed/jaccl/jaccl.h"
namespace mlx::core::distributed::ibv { namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@@ -12,9 +12,9 @@ bool is_available() {
std::shared_ptr<GroupImpl> init(bool strict /* = false */) { std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) { if (strict) {
throw std::runtime_error("Cannot initialize ibv distributed backend."); throw std::runtime_error("Cannot initialize jaccl distributed backend.");
} }
return nullptr; return nullptr;
} }
} // namespace mlx::core::distributed::ibv } // namespace mlx::core::distributed::jaccl

View File

@@ -95,7 +95,7 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize. backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``ibv``, ``any``. If Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
set to ``any`` all available backends are tried and the first one set to ``any`` all available backends are tried and the first one
that succeeds becomes the global group which will be returned in that succeeds becomes the global group which will be returned in
subsequent calls. Default: ``any`` subsequent calls. Default: ``any``