Ring distributed backend (#1784)

This commit is contained in:
Angelos Katharopoulos 2025-01-27 22:15:01 -08:00 committed by GitHub
parent 2235dee906
commit ccb61d7aae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1078 additions and 44 deletions

View File

@ -160,6 +160,7 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
/bin/bash python/tests/run_ring_test.sh
- run: - run:
name: Build example extension name: Build example extension
command: | command: |

View File

@ -5,3 +5,4 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)

View File

@ -1,8 +1,11 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <unordered_map>
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/ring/ring.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
@ -65,7 +68,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail } // namespace detail
bool is_available() { bool is_available() {
return mpi::is_available(); return mpi::is_available() || ring::is_available();
} }
int Group::rank() const { int Group::rank() const {
@ -80,20 +83,50 @@ Group Group::split(int color, int key /* = -1 */) const {
return Group(group_->split(color, key)); return Group(group_->split(color, key));
} }
Group init(bool strict /* = false */) { Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
auto init_group = [strict]() { static std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>
auto default_group = mpi::init(strict); backends;
if (default_group == nullptr) {
default_group = std::make_shared<detail::EmptyGroup>(); // Already initialized so return the group.
if (auto g = backends.find(bk); g != backends.end()) {
return Group(g->second);
} }
return default_group;
}; // Create the requested communication group
static std::shared_ptr<detail::GroupImpl> default_group = init_group(); std::shared_ptr<detail::GroupImpl> group;
std::string bk_ = bk;
if (bk == "mpi") {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";
if (group == nullptr) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}
if (group == nullptr) {
group = std::make_shared<detail::EmptyGroup>();
} else {
backends.insert({"any", group});
}
backends.insert({std::move(bk_), group});
// Ensure the communication stream is alive before // Ensure the communication stream is alive before
// the graph is evaluated // the graph is evaluated
detail::communication_stream(); detail::communication_stream();
return Group(default_group); return Group(group);
} }
} // namespace mlx::core::distributed } // namespace mlx::core::distributed

View File

@ -53,6 +53,6 @@ struct Group {
* distributed subsystem. Otherwise simply return a singleton group which will * distributed subsystem. Otherwise simply return a singleton group which will
* render communication operations as no-op. * render communication operations as no-op.
*/ */
Group init(bool strict = false); Group init(bool strict = false, const std::string& bk = "any");
} // namespace mlx::core::distributed } // namespace mlx::core::distributed

View File

@ -11,6 +11,8 @@ namespace mlx::core::distributed::detail {
*/ */
class GroupImpl { class GroupImpl {
public: public:
virtual ~GroupImpl() {}
virtual int rank() = 0; virtual int rank() = 0;
virtual int size() = 0; virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0; virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;

View File

@ -0,0 +1,5 @@
# Compile the real ring backend only when the CPU backend is built; otherwise
# fall back to the no_ring.cpp stubs (is_available() == false, init() ==
# nullptr/throw) so the symbols still resolve at link time.
if(MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)
endif()

View File

@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed::ring {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
// Stub build: the ring backend was compiled out, so it is never available.
bool is_available() {
  return false;
}
// Stub build: there is no ring backend to initialize. When `strict` is set
// the caller demanded this backend, so fail loudly; otherwise signal
// "unavailable" with a null group.
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (!strict) {
    return nullptr;
  }
  throw std::runtime_error("Cannot initialize ring distributed backend.");
}
} // namespace mlx::core::distributed::ring

View File

@ -0,0 +1,827 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include <json.hpp>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h"
// Dispatch on the runtime dtype of array `x`: each case aliases `T` to the
// corresponding C++ element type and evaluates __VA_ARGS__ once with that
// alias in scope (used below to instantiate the templated all_sum<T>).
#define SWITCH_TYPE(x, ...)  \
  switch ((x).dtype()) {     \
    case bool_: {            \
      using T = bool;        \
      __VA_ARGS__;           \
    } break;                 \
    case int8: {             \
      using T = int8_t;      \
      __VA_ARGS__;           \
    } break;                 \
    case int16: {            \
      using T = int16_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int32: {            \
      using T = int32_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int64: {            \
      using T = int64_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint8: {            \
      using T = uint8_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint16: {           \
      using T = uint16_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint32: {           \
      using T = uint32_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint64: {           \
      using T = uint64_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case bfloat16: {         \
      using T = bfloat16_t;  \
      __VA_ARGS__;           \
    } break;                 \
    case float16: {          \
      using T = float16_t;   \
      __VA_ARGS__;           \
    } break;                 \
    case float32: {          \
      using T = float;       \
      __VA_ARGS__;           \
    } break;                 \
    case complex64: {        \
      using T = complex64_t; \
      __VA_ARGS__;           \
    } break;                 \
  }
namespace mlx::core::distributed::ring {
// Bytes moved per network packet; also the size of the stack staging buffer
// in _recv_sum and the granularity used to pick a parallelization strategy
// in RingGroup::all_sum.
constexpr const size_t PACKET_SIZE = 262144;
// Connection retry policy: up to CONN_ATTEMPTS tries with exponential
// backoff starting at CONN_WAIT milliseconds.
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
namespace {
/**
 * A reusable two-phase barrier for a fixed number of threads.
 *
 * Threads block in arrive_and_wait() until `n_threads_` of them have
 * arrived; the last arrival resets the counter, flips the phase flag and
 * wakes everyone. Waiters watch the phase flag (not the counter) so the
 * barrier can be reused immediately for the next round without a late
 * waiter slipping through.
 */
class Barrier {
 public:
  explicit Barrier(int n_threads)
      : n_threads_(n_threads), count_(0), flag_(false) {}

  void arrive_and_wait() {
    std::unique_lock<std::mutex> lock(mtx_);

    // Remember the phase we arrived in; the next round flips it.
    const bool phase = flag_;

    if (++count_ == n_threads_) {
      // Last arrival: reset for the next round, flip the phase and release
      // every waiter.
      count_ = 0;
      flag_ = !flag_;
      cv_.notify_all();
      return;
    }

    // Not the last arrival: sleep until the phase changes.
    cv_.wait(lock, [this, phase]() { return flag_ != phase; });
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  int n_threads_;
  int count_;
  bool flag_; // distinguishes consecutive rounds so the barrier is reusable
};
/**
 * Stream `first` followed by each extra argument, space-separated, then
 * std::endl. Implemented with a C++17 fold expression instead of recursion.
 */
template <typename T, typename... Args>
void log(std::ostream& os, T first, Args... args) {
  os << first;
  ((os << " " << args), ...);
  os << std::endl;
}

/** Log to stderr with a "[ring]" prefix, but only when `verbose` is set. */
template <typename... Args>
void log_info(bool verbose, Args... args) {
  if (verbose) {
    log(std::cerr, "[ring]", args...);
  }
}
/**
 * Integer division of `num` by `den`, rounded up. All call sites in this
 * file pass non-negative sizes/counts.
 */
template <typename T, typename U>
decltype(T() * U()) ceildiv(T num, U den) {
  auto bumped = num + den - 1;
  return bumped / den;
}
/** A resolved socket address together with its length. */
struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  const sockaddr* get() const {
    return (struct sockaddr*)&addr;
  }
};

/**
 * Parse a sockaddr from an ip and port provided as strings.
 *
 * Only the first getaddrinfo result is kept; throws a runtime error when the
 * pair cannot be resolved.
 */
address_t parse_address(const std::string& ip, const std::string& port) {
  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC; // IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;

  addrinfo* res = nullptr;
  if (getaddrinfo(ip.c_str(), port.c_str(), &hints, &res) != 0) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip << ":" << port;
    throw std::runtime_error(msg.str());
  }

  address_t result;
  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
  result.len = res->ai_addrlen;
  freeaddrinfo(res);

  return result;
}
/**
 * Parse a sockaddr provided as an <ip>:<port> string by splitting on the
 * first ':' and delegating to the two-argument overload.
 */
address_t parse_address(const std::string& ip_port) {
  auto colon = ip_port.find(":");
  if (colon == std::string::npos) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip_port;
    throw std::runtime_error(msg.str());
  }
  return parse_address(ip_port.substr(0, colon), ip_port.substr(colon + 1));
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
* that we can have multiple connections between peers.
*
* For example:
* [
* ["ip1:5000", "ip1:5001"],
* ["ip2:5000", "ip2:5001"],
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<address_t> host;
for (auto& ips : h) {
host.push_back(std::move(parse_address(ips.get<std::string>())));
}
nodes.push_back(std::move(host));
}
return nodes;
}
/**
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
}
return sockets;
}
/**
* The counterpoint of `accept_connections`. Basically connect to each of the
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
}
return sockets;
}
// Return `arr` unchanged when it is already row contiguous, otherwise
// materialize and return a row-contiguous copy (the socket send/recv paths
// below assume a dense, row-major buffer).
array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}
/** Elementwise add the N values of `input` into `output` (output[i] += input[i]). */
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
  for (size_t i = 0; i < N; i++) {
    output[i] += input[i];
  }
}
/**
 * Blocking send of elements [start, stop) of `data` over `sock`.
 *
 * Loops until every byte has been handed to the kernel; throws a runtime
 * error if send() reports an error or a closed connection. An empty range
 * is a no-op.
 */
template <typename T>
void _send(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }

  const char* buffer = (const char*)(data + start);
  size_t remaining = (stop - start) * sizeof(T);
  while (remaining > 0) {
    ssize_t written = send(sock, buffer, remaining, 0);
    if (written <= 0) {
      std::ostringstream msg;
      msg << "Send of " << remaining << " bytes failed (errno: " << errno
          << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += written;
    remaining -= written;
  }
}
/**
 * Blocking receive of elements [start, stop) into `data` from `sock`.
 *
 * Loops until the full byte count has arrived; throws a runtime error if
 * recv() reports an error or the peer closed the connection. An empty range
 * is a no-op.
 */
template <typename T>
void _recv(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }

  char* buffer = (char*)(data + start);
  size_t remaining = (stop - start) * sizeof(T);
  while (remaining > 0) {
    ssize_t nread = recv(sock, buffer, remaining, 0);
    if (nread <= 0) {
      std::ostringstream msg;
      msg << "Recv of " << remaining << " bytes failed (errno: " << errno
          << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += nread;
    remaining -= nread;
  }
}
/**
 * Receive elements [start, stop) from `sock` and accumulate (+=) them into
 * `data` instead of overwriting it.
 *
 * Incoming bytes are staged through a stack buffer of PACKET_SIZE bytes.
 * Because a single recv() may stop in the middle of an element, the inner
 * loop keeps reading until a whole number of T values has arrived before
 * summing them into `data`. Throws on socket error or peer close.
 */
template <typename T>
void _recv_sum(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }
  data += start;
  char buffer[PACKET_SIZE];
  size_t len = (stop - start) * sizeof(T);
  while (len > 0) {
    ssize_t r = 0;
    do {
      // Continue the current packet: r bytes are already staged, read at
      // most up to the packet (or remaining-message) boundary.
      ssize_t partial_r =
          recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
      if (partial_r <= 0) {
        std::ostringstream msg;
        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      r += partial_r;
    } while (r % sizeof(T)); // only stop on a whole-element boundary
    sum_inplace((const T*)buffer, data, r / sizeof(T));
    data += r / sizeof(T);
    len -= r;
  }
}
/**
 * The sending half of a ring reduction over one socket.
 *
 * The buffer is split into `size` segments; starting from segment `rank`
 * this sends size-1 segments, then another size-1 segments, stepping by
 * `direction` each round. The shared `barrier` keeps this thread in lockstep
 * with the matching receiver thread (RingGroup enqueues a ring_send /
 * ring_recv_sum pair on the same Barrier(2)), so a segment is only sent
 * after the previous round has fully completed on both sides.
 */
template <typename T>
void ring_send(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment
  int segment = rank;

  // 1st send
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    // The last segment may be short; clamp to the end of the buffer.
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }

  // 2nd send
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}
/**
 * The receiving half of a ring reduction over one socket.
 *
 * Mirror of ring_send: starting one step behind this rank's segment, the
 * first size-1 rounds receive-and-accumulate (_recv_sum) partial sums, the
 * next size-1 rounds receive (_recv) the fully reduced segments. The shared
 * `barrier` synchronizes each round with the paired ring_send thread.
 */
template <typename T>
void ring_recv_sum(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment: one `direction` step behind this rank's own segment.
  int segment = (rank + size + direction) % size;

  // Recv sum
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv_sum<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }

  // Recv
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}
} // namespace
/**
 * A distributed group connected in a ring topology over TCP sockets.
 *
 * Rank r connects its send sockets to rank (r+1) % size and accepts its recv
 * sockets from rank r-1, possibly with several parallel connections per
 * neighbor (one entry per address in the hostfile). Only all_sum is
 * implemented; it runs ring reductions on a private thread pool, one sender
 * and one receiver thread per active channel.
 */
class RingGroup : public GroupImpl {
 public:
  RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
      : rank_(rank), verbose_(verbose), pool_(0) {
    // NOTE(review): a negative rank is not rejected here and would index
    // `nodes` out of bounds below — callers must pass a valid rank.
    if (rank_ > 0 && rank_ >= nodes.size()) {
      throw std::runtime_error(
          "[ring] Rank cannot be larger than the size of the group");
    }
    size_ = nodes.size();
    int connect_to = (rank_ + 1) % size_;
    // We define the connection order by having the rank_ == size_ - 1 connect
    // first and accept after.
    if (rank_ < connect_to) {
      log_info(verbose_, "Rank", rank_, "accepting");
      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
    } else {
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
      log_info(verbose_, "Rank", rank_, "accepting");
      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
    }
    // Failure if we couldn't make send or recv sockets
    if (send_sockets_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no send sockets.";
      throw std::invalid_argument(msg.str());
    }
    if (recv_sockets_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no recv sockets.";
      throw std::invalid_argument(msg.str());
    }
    // The following could be relaxed since we can define non-homogeneous rings
    // but it makes things a bit simpler for now.
    if (send_sockets_.size() != recv_sockets_.size()) {
      std::ostringstream msg;
      msg << "[ring] It is required to have as many connections to the left as "
          << "to the right but rank " << rank_ << " has "
          << send_sockets_.size() << " connections to the right and "
          << recv_sockets_.size() << " to the left.";
      throw std::invalid_argument(msg.str());
    }
    // Start the necessary threads for completely parallel operation on all
    // channels. One thread to send, one to receive per socket.
    // The extra factor of 2 covers the second, reverse-direction ring
    // reduction that the large all_sum path runs on each channel at the
    // same time (4 tasks per socket index i below).
    pool_.resize(send_sockets_.size() * 2 * 2);
  }

  ~RingGroup() {
    // Close both directions of every channel.
    for (auto s : send_sockets_) {
      shutdown(s, 2);
      close(s);
    }
    for (auto s : recv_sockets_) {
      shutdown(s, 2);
      close(s);
    }
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  void all_sum(const array& input_, array& output) override {
    // Dispatch the templated implementation on the output dtype.
    SWITCH_TYPE(output, all_sum<T>(input_, output));
  }

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[ring] Group split not supported.");
  }

  void all_gather(const array& input, array& output) override {
    throw std::runtime_error("[ring] All gather not supported.");
  }

  void send(const array& input, int dst) override {
    throw std::runtime_error("[ring] Send not supported.");
  }

  void recv(array& out, int src) override {
    throw std::runtime_error("[ring] Recv not supported.");
  }

 private:
  template <typename T>
  void all_sum(const array& input_, array& output) {
    // Make sure that the input is row contiguous
    array input = ensure_row_contiguous(input_);

    // If the input data cannot be split into size_ segments then copy it and
    // all reduce a local buffer prefilled with 0s.
    if (input.size() < size_) {
      // TODO: Maybe allocate dynamically so we don't have the constraint below?
      // (1024 must match the stack buffer declared a few lines down.)
      if (input.itemsize() * size_ > 1024) {
        std::ostringstream msg;
        msg << "Can't perform the ring all reduce of " << output.size()
            << " elements with a ring of size " << size_;
        throw std::runtime_error(msg.str());
      }

      // One zero-padded element per rank so the buffer splits evenly, then a
      // single send/recv pair over channel 0.
      std::future<void> sent, recvd;
      auto barrier = std::make_unique<Barrier>(2);
      char buffer[1024];
      std::memset(buffer, 0, size_ * input.itemsize());
      std::memcpy(buffer, input.data<char>(), input.nbytes());
      sent = pool_.enqueue(
          ring_send<T>,
          std::reference_wrapper(*barrier),
          send_sockets_[0],
          rank_,
          size_,
          (T*)buffer,
          size_,
          -1);
      recvd = pool_.enqueue(
          ring_recv_sum<T>,
          std::reference_wrapper(*barrier),
          recv_sockets_[0],
          rank_,
          size_,
          (T*)buffer,
          size_,
          -1);
      sent.wait();
      recvd.wait();
      std::memcpy(output.data<char>(), buffer, output.nbytes());
      return;
    }

    // If not inplace all reduce then copy the input to the output first
    if (input.data<void>() != output.data<void>()) {
      std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
    }

    // All reduce in place. We have `send_channels_.size()` bidirectional
    // channels so let's split the message up and perform as many parallel
    // ring-reductions as possible.
    std::vector<std::future<void>> reductions;
    std::vector<std::unique_ptr<Barrier>> barriers;

    size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);

    // Large all reduce territory so let's use all we got
    if (packets >= 2 * send_sockets_.size()) {
      // Two counter-rotating ring reductions per channel: one travels in
      // direction -1 on (send, recv), the other in direction +1 on the
      // swapped (recv, send) pair.
      size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
      for (int i = 0; i < send_sockets_.size(); i++) {
        // 1st ring reduce
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + 2 * i * segment,
            std::min(output.size() - 2 * i * segment, segment),
            -1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + 2 * i * segment,
            std::min(output.size() - 2 * i * segment, segment),
            -1));

        // 2nd ring reduce
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (2 * i + 1) * segment,
            std::min(output.size() - (2 * i + 1) * segment, segment),
            1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (2 * i + 1) * segment,
            std::min(output.size() - (2 * i + 1) * segment, segment),
            1));
      }
    }

    // At least 2 reductions so we can be from small to medium
    else if (packets > 1) {
      // One reduction per packet: forward direction on the first
      // send_sockets_.size() packets, reverse direction on the rest.
      size_t segment = ceildiv(output.size(), packets);
      for (int i = 0; i < send_sockets_.size(); i++) {
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + i * segment,
            std::min(output.size() - i * segment, segment),
            -1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + i * segment,
            std::min(output.size() - i * segment, segment),
            -1));
      }
      for (int i = 0; i < packets - send_sockets_.size(); i++) {
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (send_sockets_.size() + i) * segment,
            std::min(
                output.size() - (send_sockets_.size() + i) * segment, segment),
            1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (send_sockets_.size() + i) * segment,
            std::min(
                output.size() - (send_sockets_.size() + i) * segment, segment),
            1));
      }
    }

    // Small reduction which won't really benefit much from parallelization.
    // TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
    // fairly large array.
    else {
      barriers.emplace_back(std::make_unique<Barrier>(2));
      reductions.push_back(pool_.enqueue(
          ring_send<T>,
          std::reference_wrapper(*barriers.back()),
          send_sockets_[0],
          rank_,
          size_,
          output.data<T>(),
          output.size(),
          -1));
      reductions.push_back(pool_.enqueue(
          ring_recv_sum<T>,
          std::reference_wrapper(*barriers.back()),
          recv_sockets_[0],
          rank_,
          size_,
          output.data<T>(),
          output.size(),
          -1));
    }

    // Wait for the reductions to finish.
    for (auto& f : reductions) {
      f.wait();
    }
  }

  int rank_;
  int size_;
  bool verbose_;

  ThreadPool pool_;

  std::vector<int> send_sockets_;
  std::vector<int> recv_sockets_;
};
// The full ring backend is compiled into this build, so it is always
// available.
bool is_available() {
  return true;
}
/**
 * Initialize the ring backend from the MLX_HOSTFILE and MLX_RANK environment
 * variables (MLX_RING_VERBOSE enables connection logging).
 *
 * Returns nullptr when the variables are missing and `strict` is false;
 * throws when they are missing and `strict` is true, or when MLX_RANK is not
 * a valid non-negative integer.
 */
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  const char* hostfile = std::getenv("MLX_HOSTFILE");
  const char* rank_str = std::getenv("MLX_RANK");
  const char* ring_verbose = std::getenv("MLX_RING_VERBOSE");
  if (!hostfile || !rank_str) {
    if (strict) {
      std::ostringstream msg;
      msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) "
          << "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\""
          << ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\""
          << ((hostfile) ? hostfile : "") << "\"";
      throw std::runtime_error(msg.str());
    }
    return nullptr;
  }

  auto nodes = load_nodes(hostfile);

  // Parse and validate the rank. std::atoi silently returns 0 on garbage and
  // accepts negative values, which would index `nodes` out of bounds inside
  // RingGroup, so parse with strtol and check the result.
  char* end = nullptr;
  long rank = std::strtol(rank_str, &end, 10);
  if (end == rank_str || *end != '\0' || rank < 0) {
    std::ostringstream msg;
    msg << "[ring] Invalid rank provided MLX_RANK=\"" << rank_str << "\"";
    throw std::runtime_error(msg.str());
  }

  // Move the node list in; RingGroup takes it by value.
  return std::make_shared<RingGroup>(
      static_cast<int>(rank), std::move(nodes), ring_verbose != nullptr);
}
} // namespace mlx::core::distributed::ring

View File

@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::ring {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::ring

View File

@ -13,7 +13,7 @@
#include <unistd.h> #include <unistd.h>
#endif #endif
#include "mlx/io/threadpool.h" #include "mlx/threadpool.h"
// Strictly we need to operate on files in binary mode (to avoid \r getting // Strictly we need to operate on files in binary mode (to avoid \r getting
// automatically inserted), but every modern system except for Windows no // automatically inserted), but every modern system except for Windows no

View File

@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs(
const array& a, const array& a,
const array& start, const array& start,
std::vector<int>& axes, std::vector<int>& axes,
const std::string prefix) { std::string_view prefix) {
if (start.size() > a.ndim()) { if (start.size() > a.ndim()) {
std::ostringstream msg; std::ostringstream msg;
msg << prefix << " Invalid number of starting positions for " msg << prefix << " Invalid number of starting positions for "
@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs(
} }
std::set dims(axes.begin(), axes.end()); std::set dims(axes.begin(), axes.end());
if (dims.size() != axes.size()) { if (dims.size() != axes.size()) {
throw std::invalid_argument(prefix + " Repeat axes not allowed."); std::ostringstream msg;
msg << prefix << " Repeat axes not allowed.";
throw std::invalid_argument(msg.str());
} }
} }
@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
std::vector<array> meshgrid( std::vector<array> meshgrid(
const std::vector<array>& arrays, const std::vector<array>& arrays,
bool sparse /* = false */, bool sparse /* = false */,
std::string indexing /* = "xy" */, const std::string& indexing /* = "xy" */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (indexing != "xy" && indexing != "ij") { if (indexing != "xy" && indexing != "ij") {
throw std::invalid_argument( throw std::invalid_argument(
@ -1186,7 +1188,7 @@ array pad(
const Shape& low_pad_size, const Shape& low_pad_size,
const Shape& high_pad_size, const Shape& high_pad_size,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/, const std::string& mode /*= "constant"*/,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
if (axes.size() != low_pad_size.size() || if (axes.size() != low_pad_size.size() ||
axes.size() != high_pad_size.size()) { axes.size() != high_pad_size.size()) {
@ -1238,7 +1240,7 @@ array pad(
const array& a, const array& a,
const std::vector<std::pair<int, int>>& pad_width, const std::vector<std::pair<int, int>>& pad_width,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/, const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) { StreamOrDevice s /*= {}*/) {
std::vector<int> axes(a.ndim(), 0); std::vector<int> axes(a.ndim(), 0);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
@ -1258,7 +1260,7 @@ array pad(
const array& a, const array& a,
const std::pair<int, int>& pad_width, const std::pair<int, int>& pad_width,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/, const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) { StreamOrDevice s /*= {}*/) {
return pad( return pad(
a, a,
@ -1272,7 +1274,7 @@ array pad(
const array& a, const array& a,
int pad_width, int pad_width,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/, const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) { StreamOrDevice s /*= {}*/) {
return pad( return pad(
a, a,

View File

@ -222,7 +222,7 @@ split(const array& a, const Shape& indices, StreamOrDevice s = {});
std::vector<array> meshgrid( std::vector<array> meshgrid(
const std::vector<array>& arrays, const std::vector<array>& arrays,
bool sparse = false, bool sparse = false,
std::string indexing = "xy", const std::string& indexing = "xy",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** /**
@ -274,7 +274,7 @@ array pad(
const Shape& low_pad_size, const Shape& low_pad_size,
const Shape& high_pad_size, const Shape& high_pad_size,
const array& pad_value = array(0), const array& pad_value = array(0),
const std::string mode = "constant", const std::string& mode = "constant",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Pad an array with a constant value along all axes */ /** Pad an array with a constant value along all axes */
@ -282,19 +282,19 @@ array pad(
const array& a, const array& a,
const std::vector<std::pair<int, int>>& pad_width, const std::vector<std::pair<int, int>>& pad_width,
const array& pad_value = array(0), const array& pad_value = array(0),
const std::string mode = "constant", const std::string& mode = "constant",
StreamOrDevice s = {}); StreamOrDevice s = {});
array pad( array pad(
const array& a, const array& a,
const std::pair<int, int>& pad_width, const std::pair<int, int>& pad_width,
const array& pad_value = array(0), const array& pad_value = array(0),
const std::string mode = "constant", const std::string& mode = "constant",
StreamOrDevice s = {}); StreamOrDevice s = {});
array pad( array pad(
const array& a, const array& a,
int pad_width, int pad_width,
const array& pad_value = array(0), const array& pad_value = array(0),
const std::string mode = "constant", const std::string& mode = "constant",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Permutes the dimensions in reverse order. */ /** Permutes the dimensions in reverse order. */

View File

@ -38,9 +38,13 @@ class ThreadPool {
template <class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::invoke_result_t<F, Args...>>; -> std::future<typename std::invoke_result_t<F, Args...>>;
void resize(size_t);
~ThreadPool(); ~ThreadPool();
private: private:
void stop_and_wait();
void start_threads(size_t);
std::vector<std::thread> workers; std::vector<std::thread> workers;
std::queue<std::function<void()>> tasks; std::queue<std::function<void()>> tasks;
std::mutex queue_mutex; std::mutex queue_mutex;
@ -49,24 +53,7 @@ class ThreadPool {
}; };
inline ThreadPool::ThreadPool(size_t threads) : stop(false) { inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i) start_threads(threads);
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
} }
template <class F, class... Args> template <class F, class... Args>
@ -92,12 +79,55 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
return res; return res;
} }
inline void ThreadPool::resize(size_t threads) {
if (workers.size() == threads) {
return;
}
if (workers.size() > threads) {
stop_and_wait();
}
start_threads(threads - workers.size());
}
inline ThreadPool::~ThreadPool() { inline ThreadPool::~ThreadPool() {
stop_and_wait();
}
inline void ThreadPool::stop_and_wait() {
// Stop the current threads and wait until they finish
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
stop = true; stop = true;
} }
condition.notify_all(); condition.notify_all();
for (std::thread& worker : workers) for (std::thread& worker : workers) {
worker.join(); worker.join();
}
// Reset the member variables so that the threadpool is reusable
stop = false;
workers.clear();
}
inline void ThreadPool::start_threads(size_t threads) {
for (size_t i = 0; i < threads; ++i) {
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
} }

View File

@ -3,6 +3,7 @@
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h> #include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) {
"init", "init",
&mx::distributed::init, &mx::distributed::init,
"strict"_a = false, "strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"), "backend"_a = "any",
nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
R"pbdoc( R"pbdoc(
Initialize the communication backend and create the global communication group. Initialize the communication backend and create the global communication group.
Example:
import mlx.core as mx
group = mx.distributed.init(backend="ring")
Args: Args:
strict (bool, optional): If set to False it returns a singleton group strict (bool, optional): If set to False it returns a singleton group
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Select a specific distributed backend to
initialize. If set to ``any`` then try all available backends and
return the first one that succeeds. Subsequent calls will return
the first backend that was initialized. Default: ``any``
Returns: Returns:
Group: The group representing all the launched processes. Group: The group representing all the launched processes.

View File

@ -34,6 +34,8 @@ class TestDistributed(mlx_tests.MLXTestCase):
mx.int32, mx.int32,
mx.uint32, mx.uint32,
mx.float32, mx.float32,
mx.float16,
mx.bfloat16,
mx.complex64, mx.complex64,
] ]
for dt in dtypes: for dt in dtypes:

View File

@ -0,0 +1,61 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestRingDistributed(mlx_tests.MLXTestCase):
    @classmethod
    def setUpClass(cls):
        # Initialize the ring backend once for the whole test case.
        # strict=True makes initialization fail loudly instead of falling
        # back to a singleton group. Subsequent mx.distributed.init() calls
        # in the tests return this already-initialized backend, so the
        # return value does not need to be kept here.
        mx.distributed.init(strict=True, backend="ring")

    def test_groups(self):
        """The global group has 8 ranks and repeated init() is idempotent."""
        world = mx.distributed.init()
        # The launcher script starts exactly 8 local processes.
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        # The ring backend does not support splitting the group.
        with self.assertRaises(RuntimeError):
            world.split(world.rank() % 2)

    def test_all_reduce(self):
        """all_sum across all ranks matches a local sum, per dtype/shape."""
        world = mx.distributed.init()
        # (dtype, relative tolerance) pairs; integer types must be exact.
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
            (mx.complex64, 1e-6),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)
        for dt, rtol in dtypes:
            for sh in sizes:
                # Fixed key: every rank generates the identical tensor x, so
                # each rank can compute the expected reduction locally.
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)
                y = mx.distributed.all_sum(x[world.rank()])
                z = sum(
                    x[i] for i in range(world.size())
                )  # to ensure that we don't sum to int32
                # NOTE(review): if an element of z is exactly 0 this divides
                # 0/0 — presumably rare enough with these inputs; confirm.
                maxrelerror = ((y - z).abs() / z.abs()).max()
                self.assertLessEqual(maxrelerror, rtol)


if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Launch an 8-process, localhost-only test of the ring distributed backend.
# Each rank is selected with MLX_RANK and finds its peers in the JSON
# hostfile written below (one TCP endpoint per rank on ports 5000-5007).

# Write the hostfile to a temporary path and remove it when the script exits.
tmpfile=$(mktemp)
trap 'rm -f "$tmpfile"' EXIT

cat <<HOSTFILE >"$tmpfile"
[
["127.0.0.1:5000"],
["127.0.0.1:5001"],
["127.0.0.1:5002"],
["127.0.0.1:5003"],
["127.0.0.1:5004"],
["127.0.0.1:5005"],
["127.0.0.1:5006"],
["127.0.0.1:5007"]
]
HOSTFILE

# Resolve the test script relative to this file so the script can be run
# from any working directory. Quoting guards against spaces in the path.
ring_test="$(dirname "${BASH_SOURCE[0]}")/ring_test_distributed.py"

# Start all 8 ranks in the background. The last rank is delayed slightly,
# presumably so the other ranks are already listening when it connects.
for i in {0..7}; do
    if ((i == 7)); then
        sleep 1
    fi
    DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE="$tmpfile" MLX_RANK=$i python "$ring_test" &
done

# Wait for every rank to finish; the trap then deletes the hostfile.
wait