mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Change the name to a fun pun
This commit is contained in:
@@ -11,4 +11,4 @@ endif()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/ibv/ibv.h"
|
||||
#include "mlx/distributed/jaccl/jaccl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
@@ -104,7 +104,7 @@ class EmptyGroup : public GroupImpl {
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||
ibv::is_available();
|
||||
jaccl::is_available();
|
||||
}
|
||||
|
||||
bool is_available(const std::string& bk) {
|
||||
@@ -120,8 +120,8 @@ bool is_available(const std::string& bk) {
|
||||
if (bk == "nccl") {
|
||||
return nccl::is_available();
|
||||
}
|
||||
if (bk == "ibv") {
|
||||
return ibv::is_available();
|
||||
if (bk == "jaccl") {
|
||||
return jaccl::is_available();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -156,8 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "ibv") {
|
||||
group = ibv::init(strict);
|
||||
} else if (bk == "jaccl") {
|
||||
group = jaccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
if (mlx::core::cu::is_available()) {
|
||||
group = nccl::init(false);
|
||||
@@ -172,8 +172,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
bk_ = "mpi";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = ibv::init(false);
|
||||
bk_ = "ibv";
|
||||
group = jaccl::init(false);
|
||||
bk_ = "jaccl";
|
||||
}
|
||||
if (group == nullptr && strict) {
|
||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||
@@ -181,7 +181,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
|
||||
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
if(MLX_BUILD_CPU
|
||||
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
|
||||
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
|
||||
target_link_libraries(mlx PRIVATE rdma)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
|
||||
endif()
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "mlx/distributed/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
constexpr const char* IBV_TAG = "[ibv]";
|
||||
constexpr const char* IBV_TAG = "[jaccl]";
|
||||
constexpr int NUM_BUFFERS = 2;
|
||||
constexpr int BUFFER_SIZE = 4096;
|
||||
constexpr int MAX_SEND_WR = 32;
|
||||
@@ -103,7 +103,7 @@ class SharedBuffer {
|
||||
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
|
||||
if (!inserted) {
|
||||
throw std::runtime_error(
|
||||
"[ibv] Buffer can be registered once per protection domain");
|
||||
"[jaccl] Buffer can be registered once per protection domain");
|
||||
}
|
||||
|
||||
it->second = ibv_reg_mr(
|
||||
@@ -113,7 +113,7 @@ class SharedBuffer {
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||
IBV_ACCESS_REMOTE_WRITE);
|
||||
if (!it->second) {
|
||||
throw std::runtime_error("[ibv] Register memory region failed");
|
||||
throw std::runtime_error("[jaccl] Register memory region failed");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,14 +205,14 @@ struct Connection {
|
||||
void allocate_protection_domain() {
|
||||
protection_domain = ibv_alloc_pd(ctx);
|
||||
if (protection_domain == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't allocate protection domain");
|
||||
throw std::runtime_error("[jaccl] Couldn't allocate protection domain");
|
||||
}
|
||||
}
|
||||
|
||||
void create_completion_queue(int num_entries) {
|
||||
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
|
||||
if (completion_queue == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't create completion queue");
|
||||
throw std::runtime_error("[jaccl] Couldn't create completion queue");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ struct Connection {
|
||||
queue_pair = ibv_create_qp(protection_domain, &init_attr);
|
||||
|
||||
if (queue_pair == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't create queue pair");
|
||||
throw std::runtime_error("[jaccl] Couldn't create queue pair");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +269,7 @@ struct Connection {
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
|
||||
msg << "[jaccl] Changing queue pair to INIT failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -299,7 +299,7 @@ struct Connection {
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
|
||||
msg << "[jaccl] Changing queue pair to RTR failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -313,7 +313,7 @@ struct Connection {
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
|
||||
msg << "[jaccl] Changing queue pair to RTS failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -333,7 +333,7 @@ struct Connection {
|
||||
ibv_post_send(queue_pair, &work_request, &bad_work_request);
|
||||
status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Send failed with error code " << status;
|
||||
msg << "[jaccl] Send failed with error code " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -351,7 +351,7 @@ struct Connection {
|
||||
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
|
||||
status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Recv failed with error code " << status;
|
||||
msg << "[jaccl] Recv failed with error code " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -491,7 +491,7 @@ class ConnectionManager {
|
||||
side_channel_(rank_, size_, coordinator_addr) {
|
||||
create_contexts(device_names);
|
||||
if (connections_[rank_].ctx != nullptr) {
|
||||
throw std::runtime_error("[ibv] Malformed device file");
|
||||
throw std::runtime_error("[jaccl] Malformed device file");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -653,7 +653,7 @@ class ConnectionManager {
|
||||
auto ctx = ibv_open_device(devices[i]);
|
||||
if (ctx == nullptr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Could not open device " << name;
|
||||
msg << "[jaccl] Could not open device " << name;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
connections_.emplace_back(ctx);
|
||||
@@ -688,7 +688,7 @@ std::vector<std::string> load_device_names(int rank, const char* dev_file) {
|
||||
return device_names;
|
||||
}
|
||||
|
||||
namespace mlx::core::distributed::ibv {
|
||||
namespace mlx::core::distributed::jaccl {
|
||||
|
||||
class IBVGroup : public GroupImpl {
|
||||
public:
|
||||
@@ -926,11 +926,11 @@ class IBVGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void sum_scatter(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[ring] sum_scatter not supported.");
|
||||
throw std::runtime_error("[jaccl] sum_scatter not supported.");
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
throw std::runtime_error("[ibv] Group split not supported.");
|
||||
throw std::runtime_error("[jaccl] Group split not supported.");
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -1093,7 +1093,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (!is_available() || !dev_file || !coordinator || !rank_str) {
|
||||
if (strict) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
|
||||
msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
|
||||
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
||||
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
||||
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
||||
@@ -1110,7 +1110,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
auto cm = ConnectionManager(rank, device_names, coordinator);
|
||||
if (cm.size() > MAX_PEERS) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS
|
||||
msg << "[jaccl] The maximum number of supported peers is " << MAX_PEERS
|
||||
<< " but " << cm.size() << " was provided";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
@@ -1121,4 +1121,4 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
return std::make_shared<IBVGroup>(std::move(cm));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::ibv
|
||||
} // namespace mlx::core::distributed::jaccl
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::ibv {
|
||||
namespace mlx::core::distributed::jaccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available();
|
||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||
|
||||
} // namespace mlx::core::distributed::ibv
|
||||
} // namespace mlx::core::distributed::jaccl
|
||||
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/ibv/ibv.h"
|
||||
#include "mlx/distributed/jaccl/jaccl.h"
|
||||
|
||||
namespace mlx::core::distributed::ibv {
|
||||
namespace mlx::core::distributed::jaccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
@@ -12,9 +12,9 @@ bool is_available() {
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize ibv distributed backend.");
|
||||
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::ibv
|
||||
} // namespace mlx::core::distributed::jaccl
|
||||
@@ -95,7 +95,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||
it throws a runtime error. Default: ``False``
|
||||
backend (str, optional): Which distributed backend to initialize.
|
||||
Possible values ``mpi``, ``ring``, ``nccl``, ``ibv``, ``any``. If
|
||||
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
|
||||
set to ``any`` all available backends are tried and the first one
|
||||
that succeeds becomes the global group which will be returned in
|
||||
subsequent calls. Default: ``any``
|
||||
|
||||
Reference in New Issue
Block a user