mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
NCCL backend (#2476)
This commit is contained in:

committed by
GitHub

parent
e843c4d8d5
commit
9392fc3f88
@@ -22,6 +22,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
|
51
mlx/backend/cuda/distributed.cu
Normal file
51
mlx/backend/cuda/distributed.cu
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace distributed {
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& input = inputs[0];
|
||||
auto& output = outputs[0];
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
|
||||
if (input.is_donatable()) {
|
||||
output.copy_shared_buffer(input);
|
||||
} else {
|
||||
output.set_data(allocator::malloc(output.nbytes()));
|
||||
}
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
auto& s = stream();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), input, output, s);
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), input, output, s);
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), input, output, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, max, and min are supported.");
|
||||
}
|
||||
}
|
||||
} // namespace distributed
|
||||
} // namespace mlx::core
|
@@ -42,7 +42,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
|
@@ -6,3 +6,4 @@ target_sources(
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||
|
@@ -5,12 +5,17 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
namespace detail {
|
||||
|
||||
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
|
||||
return group.raw_group()->communication_stream(s);
|
||||
}
|
||||
|
||||
void all_sum(Group group, const array& input, array& output, Stream stream) {
|
||||
group.raw_group()->all_sum(input, output, stream);
|
||||
}
|
||||
@@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) {
|
||||
|
||||
class EmptyGroup : public GroupImpl {
|
||||
public:
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
return to_stream(s);
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return 0;
|
||||
}
|
||||
@@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl {
|
||||
} // namespace detail
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available() || ring::is_available();
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||
}
|
||||
|
||||
int Group::rank() const {
|
||||
@@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = mpi::init(strict);
|
||||
} else if (bk == "ring") {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
|
@@ -13,10 +13,15 @@ class GroupImpl {
|
||||
public:
|
||||
virtual ~GroupImpl() {}
|
||||
|
||||
// Choose the stream this communication group can operate on
|
||||
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
|
||||
|
||||
// Group operations
|
||||
virtual int rank() = 0;
|
||||
virtual int size() = 0;
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
||||
|
||||
// Actual communication operations
|
||||
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void send(const array& input, int dst, Stream stream) = 0;
|
||||
@@ -25,6 +30,9 @@ class GroupImpl {
|
||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||
};
|
||||
|
||||
/* Define the MLX stream that the communication should happen in. */
|
||||
Stream communication_stream(Group group, StreamOrDevice s = {});
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_sum(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
|
@@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
|
||||
}
|
||||
}
|
||||
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
return to_stream(s, Device::cpu);
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
if (rank_ < 0) {
|
||||
mpi().rank(comm_, &rank_);
|
||||
|
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
if(MLX_BUILD_CUDA)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
||||
find_package(NCCL REQUIRED)
|
||||
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
||||
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
||||
endif()
|
359
mlx/distributed/nccl/nccl.cpp
Normal file
359
mlx/distributed/nccl/nccl.cpp
Normal file
@@ -0,0 +1,359 @@
|
||||
#include <arpa/inet.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
#define CHECK_CUDA(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
fprintf( \
|
||||
stderr, \
|
||||
"CUDA error %s:%d '%s'\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CHECK_NCCL(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
fprintf( \
|
||||
stderr, \
|
||||
"NCCL error %s:%d '%s'\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define MLX_NCCL_TYPE_LIST(X) \
|
||||
X(int8_t, ncclChar) \
|
||||
X(uint8_t, ncclUint8) \
|
||||
X(int32_t, ncclInt) \
|
||||
X(uint32_t, ncclUint32) \
|
||||
X(int64_t, ncclInt64) \
|
||||
X(uint64_t, ncclUint64) \
|
||||
X(float16_t, ncclHalf) \
|
||||
X(bfloat16_t, ncclBfloat16) \
|
||||
X(float, ncclFloat) \
|
||||
X(double, ncclDouble)
|
||||
|
||||
template <class>
|
||||
struct nccl_map {
|
||||
static constexpr bool ok = false; // default: unsupported
|
||||
};
|
||||
|
||||
#define MLX_DEF_NCCL_MAP(T, E) \
|
||||
template <> \
|
||||
struct nccl_map<T> { \
|
||||
static constexpr bool ok = true; \
|
||||
static constexpr ncclDataType_t value = E; \
|
||||
};
|
||||
|
||||
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
|
||||
#undef MLX_DEF_NCCL_MAP
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F>
|
||||
void dispatch_dtype(const array& arr, F&& f) {
|
||||
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
|
||||
using T = MLX_GET_TYPE(type_tag);
|
||||
if constexpr (nccl_map<T>::ok) {
|
||||
f(type_tag, nccl_map<T>::value);
|
||||
} else {
|
||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||
while (len > 0) {
|
||||
ssize_t sent = send(sock, ptr, len, 0);
|
||||
if (sent <= 0) {
|
||||
perror("send");
|
||||
exit(1);
|
||||
}
|
||||
ptr += sent;
|
||||
len -= sent;
|
||||
}
|
||||
}
|
||||
|
||||
inline void recvAll(int sock, void* buf, size_t len) {
|
||||
char* ptr = reinterpret_cast<char*>(buf);
|
||||
while (len > 0) {
|
||||
ssize_t rec = recv(sock, ptr, len, 0);
|
||||
if (rec <= 0) {
|
||||
perror("recv");
|
||||
exit(1);
|
||||
}
|
||||
ptr += rec;
|
||||
len -= rec;
|
||||
}
|
||||
}
|
||||
|
||||
inline void bootstrap_unique_id(
|
||||
ncclUniqueId& id,
|
||||
int rank,
|
||||
int size,
|
||||
const std::string& initMethod) {
|
||||
// Parse the init method to extract the host and port
|
||||
if (initMethod.rfind("tcp://", 0) != 0)
|
||||
throw;
|
||||
auto hostport = initMethod.substr(6);
|
||||
auto colon = hostport.find(':');
|
||||
std::string host = hostport.substr(0, colon);
|
||||
int port = std::stoi(hostport.substr(colon + 1));
|
||||
|
||||
if (rank == 0) {
|
||||
// create a unique id on the rank 0
|
||||
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||
|
||||
// create a socket to send the unique id to all other ranks
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockaddr_in serv = {};
|
||||
serv.sin_family = AF_INET;
|
||||
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv.sin_port = htons(port);
|
||||
|
||||
int reuse = 1;
|
||||
// Without this, if rank-0 crashes or restarts process quickly,
|
||||
// the OS might refuse to let binding to the same port, so reuse
|
||||
|
||||
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
if (listen(sock, size - 1) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
for (int peer = 1; peer < size; ++peer) {
|
||||
int conn = accept(sock, nullptr, nullptr);
|
||||
if (conn < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
sendAll(conn, &id, sizeof(id));
|
||||
close(conn);
|
||||
}
|
||||
close(sock);
|
||||
|
||||
} else {
|
||||
// Here just wanted to make show that rank 0 has enough time to bind
|
||||
// so we will retry to connect until max attempts
|
||||
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
hostent* he = gethostbyname(host.c_str());
|
||||
if (!he) {
|
||||
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||
}
|
||||
sockaddr_in serv = {};
|
||||
serv.sin_family = AF_INET;
|
||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||
serv.sin_port = htons(port);
|
||||
|
||||
const int max_retries = 30;
|
||||
int attempt = 0;
|
||||
bool connected = false;
|
||||
|
||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||
0) {
|
||||
connected = true;
|
||||
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||
<< attempt + 1 << std::endl;
|
||||
break;
|
||||
}
|
||||
if (errno != ECONNREFUSED) {
|
||||
break;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
if (!connected) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||
<< " retries: " << strerror(errno);
|
||||
close(sock);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
recvAll(sock, &id, sizeof(id));
|
||||
close(sock);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
class NCCLGroup : public GroupImpl {
|
||||
public:
|
||||
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||
: rank_(worldRank),
|
||||
size_(worldSize),
|
||||
comm_(nullptr),
|
||||
initMethod_(initMethod) {
|
||||
if (initialized_)
|
||||
return;
|
||||
int ndev;
|
||||
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
|
||||
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
~NCCLGroup() {
|
||||
ncclCommDestroy(comm_);
|
||||
ncclGroupEnd();
|
||||
initialized_ = false;
|
||||
}
|
||||
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
return to_stream(s, Device::gpu);
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int size() override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||
});
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
throw std::runtime_error("[nccl] Group split not supported.");
|
||||
}
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error(
|
||||
"[nccl] All gather not supported in NCCL backend.");
|
||||
}
|
||||
|
||||
void send(const array& input, int dst, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
|
||||
}
|
||||
|
||||
void recv(array& output, int src, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void all_reduce_impl(
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ncclDataType_t dt,
|
||||
ncclRedOp_t op) {
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
CHECK_NCCL(ncclAllReduce(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
input.size(),
|
||||
dt,
|
||||
op,
|
||||
comm_,
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
int rank_, size_;
|
||||
std::string initMethod_;
|
||||
ncclUniqueId uniqueId_;
|
||||
ncclComm_t comm_;
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||
const char* value = std::getenv(env_var_name);
|
||||
if (value == nullptr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] Required environment variable '" << env_var_name
|
||||
<< "' is not set. "
|
||||
<< "Please set it before initializing the distributed backend.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return std::string(value);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||
|
||||
int rank = std::stoi(rank_str);
|
||||
int n_nodes = std::stoi(n_nodes_str);
|
||||
std::string init_method = "tcp://" + host + ":" + port;
|
||||
|
||||
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||
}
|
||||
} // namespace mlx::core::distributed::nccl
|
12
mlx/distributed/nccl/nccl.h
Normal file
12
mlx/distributed/nccl/nccl.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available();
|
||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||
|
||||
} // namespace mlx::core::distributed::nccl
|
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::nccl
|
@@ -2,6 +2,9 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
|
||||
@@ -28,11 +31,12 @@ array all_sum(
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
@@ -45,11 +49,12 @@ array all_max(
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
|
||||
{x});
|
||||
}
|
||||
|
||||
@@ -62,11 +67,12 @@ array all_min(
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
|
||||
{x});
|
||||
}
|
||||
|
||||
@@ -79,6 +85,7 @@ array all_gather(
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
auto result_shape = x.shape();
|
||||
if (result_shape.size() == 0) {
|
||||
@@ -89,7 +96,7 @@ array all_gather(
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
|
||||
std::make_shared<AllGather>(stream, group),
|
||||
{x});
|
||||
}
|
||||
|
||||
@@ -103,6 +110,7 @@ array send(
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot send to a singleton group");
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
if (dst < 0 || dst >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
@@ -112,10 +120,7 @@ array send(
|
||||
}
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
|
||||
{x});
|
||||
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
|
||||
}
|
||||
|
||||
array recv(
|
||||
@@ -129,6 +134,7 @@ array recv(
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
if (src < 0 || src >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
@@ -139,7 +145,7 @@ array recv(
|
||||
return array(
|
||||
std::move(shape),
|
||||
std::move(dtype),
|
||||
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
|
||||
std::make_shared<Recv>(stream, group, src),
|
||||
std::vector<array>{});
|
||||
}
|
||||
|
||||
|
@@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
}
|
||||
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
return to_stream(s, Device::cpu);
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
Reference in New Issue
Block a user