mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Change the name to a fun pun
This commit is contained in:
@@ -11,4 +11,4 @@ endif()
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/ibv/ibv.h"
|
#include "mlx/distributed/jaccl/jaccl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
@@ -104,7 +104,7 @@ class EmptyGroup : public GroupImpl {
|
|||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||||
ibv::is_available();
|
jaccl::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_available(const std::string& bk) {
|
bool is_available(const std::string& bk) {
|
||||||
@@ -120,8 +120,8 @@ bool is_available(const std::string& bk) {
|
|||||||
if (bk == "nccl") {
|
if (bk == "nccl") {
|
||||||
return nccl::is_available();
|
return nccl::is_available();
|
||||||
}
|
}
|
||||||
if (bk == "ibv") {
|
if (bk == "jaccl") {
|
||||||
return ibv::is_available();
|
return jaccl::is_available();
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -156,8 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
} else if (bk == "nccl") {
|
||||||
group = nccl::init(strict);
|
group = nccl::init(strict);
|
||||||
} else if (bk == "ibv") {
|
} else if (bk == "jaccl") {
|
||||||
group = ibv::init(strict);
|
group = jaccl::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
if (mlx::core::cu::is_available()) {
|
if (mlx::core::cu::is_available()) {
|
||||||
group = nccl::init(false);
|
group = nccl::init(false);
|
||||||
@@ -172,8 +172,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
bk_ = "mpi";
|
bk_ = "mpi";
|
||||||
}
|
}
|
||||||
if (group == nullptr) {
|
if (group == nullptr) {
|
||||||
group = ibv::init(false);
|
group = jaccl::init(false);
|
||||||
bk_ = "ibv";
|
bk_ = "jaccl";
|
||||||
}
|
}
|
||||||
if (group == nullptr && strict) {
|
if (group == nullptr && strict) {
|
||||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||||
@@ -181,7 +181,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||||
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
|
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
if(MLX_BUILD_CPU
|
if(MLX_BUILD_CPU
|
||||||
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
|
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
|
||||||
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
|
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
|
||||||
target_link_libraries(mlx PRIVATE rdma)
|
target_link_libraries(mlx PRIVATE rdma)
|
||||||
else()
|
else()
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
|
||||||
endif()
|
endif()
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
#include "mlx/distributed/utils.h"
|
#include "mlx/distributed/utils.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
constexpr const char* IBV_TAG = "[ibv]";
|
constexpr const char* IBV_TAG = "[jaccl]";
|
||||||
constexpr int NUM_BUFFERS = 2;
|
constexpr int NUM_BUFFERS = 2;
|
||||||
constexpr int BUFFER_SIZE = 4096;
|
constexpr int BUFFER_SIZE = 4096;
|
||||||
constexpr int MAX_SEND_WR = 32;
|
constexpr int MAX_SEND_WR = 32;
|
||||||
@@ -103,7 +103,7 @@ class SharedBuffer {
|
|||||||
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
|
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
|
||||||
if (!inserted) {
|
if (!inserted) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[ibv] Buffer can be registered once per protection domain");
|
"[jaccl] Buffer can be registered once per protection domain");
|
||||||
}
|
}
|
||||||
|
|
||||||
it->second = ibv_reg_mr(
|
it->second = ibv_reg_mr(
|
||||||
@@ -113,7 +113,7 @@ class SharedBuffer {
|
|||||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||||
IBV_ACCESS_REMOTE_WRITE);
|
IBV_ACCESS_REMOTE_WRITE);
|
||||||
if (!it->second) {
|
if (!it->second) {
|
||||||
throw std::runtime_error("[ibv] Register memory region failed");
|
throw std::runtime_error("[jaccl] Register memory region failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,14 +205,14 @@ struct Connection {
|
|||||||
void allocate_protection_domain() {
|
void allocate_protection_domain() {
|
||||||
protection_domain = ibv_alloc_pd(ctx);
|
protection_domain = ibv_alloc_pd(ctx);
|
||||||
if (protection_domain == nullptr) {
|
if (protection_domain == nullptr) {
|
||||||
throw std::runtime_error("[ibv] Couldn't allocate protection domain");
|
throw std::runtime_error("[jaccl] Couldn't allocate protection domain");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void create_completion_queue(int num_entries) {
|
void create_completion_queue(int num_entries) {
|
||||||
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
|
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
|
||||||
if (completion_queue == nullptr) {
|
if (completion_queue == nullptr) {
|
||||||
throw std::runtime_error("[ibv] Couldn't create completion queue");
|
throw std::runtime_error("[jaccl] Couldn't create completion queue");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ struct Connection {
|
|||||||
queue_pair = ibv_create_qp(protection_domain, &init_attr);
|
queue_pair = ibv_create_qp(protection_domain, &init_attr);
|
||||||
|
|
||||||
if (queue_pair == nullptr) {
|
if (queue_pair == nullptr) {
|
||||||
throw std::runtime_error("[ibv] Couldn't create queue pair");
|
throw std::runtime_error("[jaccl] Couldn't create queue pair");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ struct Connection {
|
|||||||
|
|
||||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
|
msg << "[jaccl] Changing queue pair to INIT failed with errno " << status;
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -299,7 +299,7 @@ struct Connection {
|
|||||||
|
|
||||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
|
msg << "[jaccl] Changing queue pair to RTR failed with errno " << status;
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -313,7 +313,7 @@ struct Connection {
|
|||||||
|
|
||||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
|
msg << "[jaccl] Changing queue pair to RTS failed with errno " << status;
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -333,7 +333,7 @@ struct Connection {
|
|||||||
ibv_post_send(queue_pair, &work_request, &bad_work_request);
|
ibv_post_send(queue_pair, &work_request, &bad_work_request);
|
||||||
status != 0) {
|
status != 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Send failed with error code " << status;
|
msg << "[jaccl] Send failed with error code " << status;
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -351,7 +351,7 @@ struct Connection {
|
|||||||
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
|
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
|
||||||
status != 0) {
|
status != 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Recv failed with error code " << status;
|
msg << "[jaccl] Recv failed with error code " << status;
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -491,7 +491,7 @@ class ConnectionManager {
|
|||||||
side_channel_(rank_, size_, coordinator_addr) {
|
side_channel_(rank_, size_, coordinator_addr) {
|
||||||
create_contexts(device_names);
|
create_contexts(device_names);
|
||||||
if (connections_[rank_].ctx != nullptr) {
|
if (connections_[rank_].ctx != nullptr) {
|
||||||
throw std::runtime_error("[ibv] Malformed device file");
|
throw std::runtime_error("[jaccl] Malformed device file");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -653,7 +653,7 @@ class ConnectionManager {
|
|||||||
auto ctx = ibv_open_device(devices[i]);
|
auto ctx = ibv_open_device(devices[i]);
|
||||||
if (ctx == nullptr) {
|
if (ctx == nullptr) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] Could not open device " << name;
|
msg << "[jaccl] Could not open device " << name;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
connections_.emplace_back(ctx);
|
connections_.emplace_back(ctx);
|
||||||
@@ -688,7 +688,7 @@ std::vector<std::string> load_device_names(int rank, const char* dev_file) {
|
|||||||
return device_names;
|
return device_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlx::core::distributed::ibv {
|
namespace mlx::core::distributed::jaccl {
|
||||||
|
|
||||||
class IBVGroup : public GroupImpl {
|
class IBVGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
@@ -926,11 +926,11 @@ class IBVGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void sum_scatter(const array& input, array& output, Stream stream) override {
|
void sum_scatter(const array& input, array& output, Stream stream) override {
|
||||||
throw std::runtime_error("[ring] sum_scatter not supported.");
|
throw std::runtime_error("[jaccl] sum_scatter not supported.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
throw std::runtime_error("[ibv] Group split not supported.");
|
throw std::runtime_error("[jaccl] Group split not supported.");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -1093,7 +1093,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
|||||||
if (!is_available() || !dev_file || !coordinator || !rank_str) {
|
if (!is_available() || !dev_file || !coordinator || !rank_str) {
|
||||||
if (strict) {
|
if (strict) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
|
msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
|
||||||
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
||||||
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
||||||
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
||||||
@@ -1110,7 +1110,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
|||||||
auto cm = ConnectionManager(rank, device_names, coordinator);
|
auto cm = ConnectionManager(rank, device_names, coordinator);
|
||||||
if (cm.size() > MAX_PEERS) {
|
if (cm.size() > MAX_PEERS) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS
|
msg << "[jaccl] The maximum number of supported peers is " << MAX_PEERS
|
||||||
<< " but " << cm.size() << " was provided";
|
<< " but " << cm.size() << " was provided";
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
@@ -1121,4 +1121,4 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
|||||||
return std::make_shared<IBVGroup>(std::move(cm));
|
return std::make_shared<IBVGroup>(std::move(cm));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::distributed::ibv
|
} // namespace mlx::core::distributed::jaccl
|
||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed::ibv {
|
namespace mlx::core::distributed::jaccl {
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
bool is_available();
|
bool is_available();
|
||||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
} // namespace mlx::core::distributed::ibv
|
} // namespace mlx::core::distributed::jaccl
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/distributed/ibv/ibv.h"
|
#include "mlx/distributed/jaccl/jaccl.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed::ibv {
|
namespace mlx::core::distributed::jaccl {
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
@@ -12,9 +12,9 @@ bool is_available() {
|
|||||||
|
|
||||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
if (strict) {
|
if (strict) {
|
||||||
throw std::runtime_error("Cannot initialize ibv distributed backend.");
|
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::distributed::ibv
|
} // namespace mlx::core::distributed::jaccl
|
||||||
@@ -95,7 +95,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Which distributed backend to initialize.
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
Possible values ``mpi``, ``ring``, ``nccl``, ``ibv``, ``any``. If
|
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
|
||||||
set to ``any`` all available backends are tried and the first one
|
set to ``any`` all available backends are tried and the first one
|
||||||
that succeeds becomes the global group which will be returned in
|
that succeeds becomes the global group which will be returned in
|
||||||
subsequent calls. Default: ``any``
|
subsequent calls. Default: ``any``
|
||||||
|
|||||||
Reference in New Issue
Block a user