mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Ring distributed backend (#1784)
This commit is contained in:
parent
2235dee906
commit
ccb61d7aae
@ -160,6 +160,7 @@ jobs:
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
/bin/bash python/tests/run_ring_test.sh
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
|
@ -5,3 +5,4 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
|
@ -1,8 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
@ -65,7 +68,7 @@ class EmptyGroup : public GroupImpl {
|
||||
} // namespace detail
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available();
|
||||
return mpi::is_available() || ring::is_available();
|
||||
}
|
||||
|
||||
int Group::rank() const {
|
||||
@ -80,20 +83,50 @@ Group Group::split(int color, int key /* = -1 */) const {
|
||||
return Group(group_->split(color, key));
|
||||
}
|
||||
|
||||
Group init(bool strict /* = false */) {
|
||||
auto init_group = [strict]() {
|
||||
auto default_group = mpi::init(strict);
|
||||
if (default_group == nullptr) {
|
||||
default_group = std::make_shared<detail::EmptyGroup>();
|
||||
Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
static std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>
|
||||
backends;
|
||||
|
||||
// Already initialized so return the group.
|
||||
if (auto g = backends.find(bk); g != backends.end()) {
|
||||
return Group(g->second);
|
||||
}
|
||||
|
||||
// Create the requested communication group
|
||||
std::shared_ptr<detail::GroupImpl> group;
|
||||
std::string bk_ = bk;
|
||||
if (bk == "mpi") {
|
||||
group = mpi::init(strict);
|
||||
} else if (bk == "ring") {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "any") {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
if (group == nullptr) {
|
||||
group = mpi::init(false);
|
||||
bk_ = "mpi";
|
||||
}
|
||||
return default_group;
|
||||
};
|
||||
static std::shared_ptr<detail::GroupImpl> default_group = init_group();
|
||||
if (group == nullptr && strict) {
|
||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||
}
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
||||
<< "and 'ring' but '" << bk << "' was provided.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (group == nullptr) {
|
||||
group = std::make_shared<detail::EmptyGroup>();
|
||||
} else {
|
||||
backends.insert({"any", group});
|
||||
}
|
||||
backends.insert({std::move(bk_), group});
|
||||
|
||||
// Ensure the communication stream is alive before
|
||||
// the graph is evaluated
|
||||
detail::communication_stream();
|
||||
return Group(default_group);
|
||||
return Group(group);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -53,6 +53,6 @@ struct Group {
|
||||
* distributed subsystem. Otherwise simply return a singleton group which will
|
||||
* render communication operations as no-op.
|
||||
*/
|
||||
Group init(bool strict = false);
|
||||
Group init(bool strict = false, const std::string& bk = "any");
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -11,6 +11,8 @@ namespace mlx::core::distributed::detail {
|
||||
*/
|
||||
class GroupImpl {
|
||||
public:
|
||||
virtual ~GroupImpl() {}
|
||||
|
||||
virtual int rank() = 0;
|
||||
virtual int size() = 0;
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
||||
|
5
mlx/distributed/ring/CMakeLists.txt
Normal file
5
mlx/distributed/ring/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
|
||||
# Build the real ring backend only when the CPU backend is enabled; otherwise
# compile the no-op stub so the ring symbols still resolve at link time.
if(MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)
endif()
|
20
mlx/distributed/ring/no_ring.cpp
Normal file
20
mlx/distributed/ring/no_ring.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
|
||||
namespace mlx::core::distributed::ring {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
// Stub built when MLX_BUILD_CPU is off (see ring/CMakeLists.txt): the ring
// backend can never provide a communication group in this configuration.
bool is_available() {
  return false;
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize ring distributed backend.");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::ring
|
827
mlx/distributed/ring/ring.cpp
Normal file
827
mlx/distributed/ring/ring.cpp
Normal file
@ -0,0 +1,827 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#define SWITCH_TYPE(x, ...) \
|
||||
switch ((x).dtype()) { \
|
||||
case bool_: { \
|
||||
using T = bool; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int8: { \
|
||||
using T = int8_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int16: { \
|
||||
using T = int16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int32: { \
|
||||
using T = int32_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int64: { \
|
||||
using T = int64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint8: { \
|
||||
using T = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint16: { \
|
||||
using T = uint16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint32: { \
|
||||
using T = uint32_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint64: { \
|
||||
using T = uint64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case bfloat16: { \
|
||||
using T = bfloat16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case float16: { \
|
||||
using T = float16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case float32: { \
|
||||
using T = float; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case complex64: { \
|
||||
using T = complex64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
}
|
||||
|
||||
namespace mlx::core::distributed::ring {
|
||||
|
||||
constexpr const size_t PACKET_SIZE = 262144;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace {
|
||||
|
||||
// Reusable rendezvous barrier: blocks callers until `n_threads_` of them have
// arrived, then releases the whole group. A generation flag (rather than a
// bare counter comparison) makes immediate back-to-back reuse safe.
class Barrier {
 public:
  explicit Barrier(int n_threads)
      : n_threads_(n_threads), count_(0), flag_(false) {}

  // Block until all n_threads_ participants have called this function.
  void arrive_and_wait() {
    std::unique_lock<std::mutex> lock(mtx_);

    // Remember which generation we arrived in; the final arrival flips it.
    const bool generation = flag_;

    if (++count_ == n_threads_) {
      // Last thread in: reset for the next cycle, flip the generation and
      // wake every waiter.
      count_ = 0;
      flag_ = !flag_;
      cv_.notify_all();
    } else {
      // Wait until the generation changes, i.e. the last thread arrived.
      cv_.wait(lock, [this, generation]() { return generation != flag_; });
    }
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  int n_threads_;

  int count_;
  bool flag_; // generation marker so the barrier can be reused sequentially
};
|
||||
|
||||
// Stream a single value followed by a newline (recursion base case).
template <typename T>
void log(std::ostream& os, T first) {
  os << first << std::endl;
}

// Stream a space-separated list of values followed by a newline.
template <typename T, typename... Args>
void log(std::ostream& os, T first, Args... args) {
  os << first << " ";
  log(os, args...);
}

// Log to stderr with a "[ring]" prefix, but only when verbose is set.
template <typename... Args>
void log_info(bool verbose, Args... args) {
  if (verbose) {
    log(std::cerr, "[ring]", args...);
  }
}
|
||||
|
||||
// Integer ceiling division: smallest q with q * denom >= num for
// non-negative operands. The result type follows the usual arithmetic
// promotions of T * U.
template <typename T, typename U>
decltype(T() * U()) ceildiv(T num, U denom) {
  return (num + denom - 1) / denom;
}
|
||||
|
||||
// A parsed socket address (sockaddr_storage is large enough for either IPv4
// or IPv6) together with its actual length.
struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  // View as the sockaddr* expected by bind()/connect().
  const sockaddr* get() const {
    return (struct sockaddr*)&addr;
  }
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
/**
 * Parse a sockaddr from an ip and port provided as strings.
 *
 * Throws std::runtime_error if getaddrinfo cannot resolve the pair.
 */
address_t parse_address(const std::string& ip, const std::string& port) {
  struct addrinfo hints, *res;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC; // accept either IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;

  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
  if (status != 0) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip << ":" << port;
    throw std::runtime_error(msg.str());
  }

  // Only the first result is used; getaddrinfo may return several.
  address_t result;
  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
  result.len = res->ai_addrlen;
  freeaddrinfo(res);

  return result;
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
/**
 * Parse a sockaddr provided as an <ip>:<port> string.
 *
 * Splits on the first ':' and delegates to the two-argument overload.
 */
address_t parse_address(const std::string& ip_port) {
  auto colon = ip_port.find(":");
  if (colon == std::string::npos) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip_port;
    throw std::runtime_error(msg.str());
  }
  return parse_address(ip_port.substr(0, colon), ip_port.substr(colon + 1));
}
|
||||
|
||||
/**
|
||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||
* addresses in order of rank. For each rank there can be many addresses so
|
||||
* that we can have multiple connections between peers.
|
||||
*
|
||||
* For example:
|
||||
* [
|
||||
* ["ip1:5000", "ip1:5001"],
|
||||
* ["ip2:5000", "ip2:5001"],
|
||||
* ["ip3:5000", "ip3:5001"],
|
||||
* ]
|
||||
*/
|
||||
/**
 * Load all addresses from the json hostfile. The hostfile is a list of
 * addresses in order of rank. For each rank there can be many addresses so
 * that we can have multiple connections between peers.
 *
 * For example:
 * [
 *   ["ip1:5000", "ip1:5001"],
 *   ["ip2:5000", "ip2:5001"],
 *   ["ip3:5000", "ip3:5001"],
 * ]
 *
 * Returns nodes[rank] = vector of parsed listening addresses for that rank.
 */
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
  std::vector<std::vector<address_t>> nodes;
  std::ifstream f(hostfile);

  // NOTE: json::parse throws on malformed input; that propagates to init().
  json hosts = json::parse(f);
  for (auto& h : hosts) {
    std::vector<address_t> host;
    for (auto& ips : h) {
      host.push_back(std::move(parse_address(ips.get<std::string>())));
    }
    nodes.push_back(std::move(host));
  }

  return nodes;
}
|
||||
|
||||
/**
|
||||
* Create a socket and accept one connection for each of the provided
|
||||
* addresses.
|
||||
*/
|
||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock, address.get(), address.len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
|
||||
sockets.push_back(peer_socket);
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
|
||||
/**
|
||||
* The counterpoint of `accept_connections`. Basically connect to each of the
|
||||
* provided addresses.
|
||||
*/
|
||||
std::vector<int> make_connections(
|
||||
const std::vector<address_t>& addresses,
|
||||
bool verbose) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
int sock;
|
||||
|
||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
||||
// backoff. TODO: Do we need that?
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"wait",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
success = connect(sock, address.get(), address.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockets.push_back(sock);
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
|
||||
// Return `arr` itself when it is already row contiguous, otherwise
// materialize a row-contiguous copy — the ring primitives below transfer raw
// contiguous bytes and cannot handle strided data.
array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}
|
||||
|
||||
// Element-wise accumulate: output[i] += input[i] for every i in [0, N).
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
  for (size_t i = 0; i < N; i++) {
    output[i] += input[i];
  }
}
|
||||
|
||||
// Blocking send of elements [start, stop) of `data` over `sock`. Loops until
// the whole byte range has been written; throws on any send() failure
// (send() returning 0 or -1, e.g. when the peer closed the connection).
template <typename T>
void _send(int sock, T* data, size_t start, size_t stop) {
  // Empty (or inverted) ranges are a no-op — callers rely on this for the
  // final short segment of a split message.
  if (stop <= start) {
    return;
  }
  data += start;
  size_t len = (stop - start) * sizeof(T);
  const char* buffer = (const char*)data;
  while (len > 0) {
    ssize_t r = send(sock, buffer, len, 0);
    if (r <= 0) {
      std::ostringstream msg;
      msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += r;
    len -= r;
  }
}
|
||||
|
||||
// Blocking receive of elements [start, stop) into `data` over `sock`. Loops
// until the whole byte range has arrived; throws on any recv() failure
// (recv() returning 0 or -1, e.g. when the peer closed the connection).
template <typename T>
void _recv(int sock, T* data, size_t start, size_t stop) {
  // Empty (or inverted) ranges are a no-op.
  if (stop <= start) {
    return;
  }
  data += start;
  size_t len = (stop - start) * sizeof(T);
  char* buffer = (char*)data;
  while (len > 0) {
    ssize_t r = recv(sock, buffer, len, 0);
    if (r <= 0) {
      std::ostringstream msg;
      msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += r;
    len -= r;
  }
}
|
||||
|
||||
// Receive elements [start, stop) over `sock` and accumulate (+=) them into
// `data`, working in PACKET_SIZE-byte chunks through a stack buffer.
template <typename T>
void _recv_sum(int sock, T* data, size_t start, size_t stop) {
  // Empty (or inverted) ranges are a no-op.
  if (stop <= start) {
    return;
  }
  data += start;
  char buffer[PACKET_SIZE];
  size_t len = (stop - start) * sizeof(T);
  while (len > 0) {
    ssize_t r = 0;
    // Keep reading until a whole number of T elements has arrived so
    // sum_inplace below never operates on a partially received element.
    // (Assumes PACKET_SIZE is a multiple of sizeof(T) — true for all types in
    // SWITCH_TYPE since PACKET_SIZE = 262144.)
    do {
      ssize_t partial_r =
          recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
      if (partial_r <= 0) {
        std::ostringstream msg;
        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      r += partial_r;
    } while (r % sizeof(T));
    sum_inplace((const T*)buffer, data, r / sizeof(T));
    data += r / sizeof(T);
    len -= r;
  }
}
|
||||
|
||||
// Sending half of a ring all-reduce over one socket. The matching
// ring_recv_sum runs concurrently on the other socket of the same pair; the
// shared Barrier keeps the two in lock-step so each segment is sent only
// after the previous round's receive completed on this node.
//
// `direction` is -1 (forward ring) or +1 (backward ring); the two phases each
// perform size-1 rounds, matching a classic reduce-scatter + all-gather ring.
template <typename T>
void ring_send(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment
  int segment = rank;

  // 1st send (reduce-scatter phase)
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    // The final segment may be short; _send treats stop <= start as a no-op.
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    // `+ size` keeps the modulus non-negative when direction == -1.
    segment = (segment + size + direction) % size;
  }

  // 2nd send (all-gather phase)
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}
|
||||
|
||||
// Receiving half of a ring all-reduce over one socket; pairs with ring_send
// through a shared Barrier (see ring_send). First phase accumulates incoming
// segments (reduce-scatter), second phase overwrites with the fully reduced
// segments (all-gather).
template <typename T>
void ring_recv_sum(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment: one step behind ring_send's starting segment, since we
  // receive what the previous rank sends.
  int segment = (rank + size + direction) % size;

  // Recv sum (reduce-scatter phase: accumulate into our copy)
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv_sum<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }

  // Recv (all-gather phase: plain overwrite with reduced segments)
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class RingGroup : public GroupImpl {
|
||||
public:
|
||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||
throw std::runtime_error(
|
||||
"[ring] Rank cannot be larger than the size of the group");
|
||||
}
|
||||
|
||||
size_ = nodes.size();
|
||||
int connect_to = (rank_ + 1) % size_;
|
||||
|
||||
// We define the connection order by having the rank_ == size_ - 1 connect
|
||||
// first and accept after.
|
||||
if (rank_ < connect_to) {
|
||||
log_info(verbose_, "Rank", rank_, "accepting");
|
||||
recv_sockets_ = std::move(accept_connections(nodes[rank_]));
|
||||
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
|
||||
send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
|
||||
} else {
|
||||
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
|
||||
send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
|
||||
log_info(verbose_, "Rank", rank_, "accepting");
|
||||
recv_sockets_ = std::move(accept_connections(nodes[rank_]));
|
||||
}
|
||||
|
||||
// Failure if we couldn't make send or recv sockets
|
||||
if (send_sockets_.empty()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Rank " << rank_ << " has no send sockets.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (recv_sockets_.empty()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Rank " << rank_ << " has no recv sockets.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// The following could be relaxed since we can define non-homogeneous rings
|
||||
// but it makes things a bit simpler for now.
|
||||
if (send_sockets_.size() != recv_sockets_.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] It is required to have as many connections to the left as "
|
||||
<< "to the right but rank " << rank_ << " has "
|
||||
<< send_sockets_.size() << " connections to the right and "
|
||||
<< recv_sockets_.size() << " to the left.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Start the necessary threads for completely parallel operation on all
|
||||
// channels. One thread to send, one to receive per socket.
|
||||
pool_.resize(send_sockets_.size() * 2 * 2);
|
||||
}
|
||||
|
||||
~RingGroup() {
|
||||
for (auto s : send_sockets_) {
|
||||
shutdown(s, 2);
|
||||
close(s);
|
||||
}
|
||||
for (auto s : recv_sockets_) {
|
||||
shutdown(s, 2);
|
||||
close(s);
|
||||
}
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int size() override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
void all_sum(const array& input_, array& output) override {
|
||||
SWITCH_TYPE(output, all_sum<T>(input_, output));
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
throw std::runtime_error("[ring] Group split not supported.");
|
||||
}
|
||||
void all_gather(const array& input, array& output) override {
|
||||
throw std::runtime_error("[ring] All gather not supported.");
|
||||
}
|
||||
void send(const array& input, int dst) override {
|
||||
throw std::runtime_error("[ring] Send not supported.");
|
||||
}
|
||||
void recv(array& out, int src) override {
|
||||
throw std::runtime_error("[ring] Recv not supported.");
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void all_sum(const array& input_, array& output) {
|
||||
// Make sure that the input is row contiguous
|
||||
array input = ensure_row_contiguous(input_);
|
||||
|
||||
// If the input data cannot be split into size_ segments then copy it and
|
||||
// all reduce a local buffer prefilled with 0s.
|
||||
if (input.size() < size_) {
|
||||
// TODO: Maybe allocate dynamically so we don't have the constraint below?
|
||||
if (input.itemsize() * size_ > 1024) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't perform the ring all reduce of " << output.size()
|
||||
<< " elements with a ring of size " << size_;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
std::future<void> sent, recvd;
|
||||
auto barrier = std::make_unique<Barrier>(2);
|
||||
char buffer[1024];
|
||||
std::memset(buffer, 0, size_ * input.itemsize());
|
||||
std::memcpy(buffer, input.data<char>(), input.nbytes());
|
||||
sent = pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barrier),
|
||||
send_sockets_[0],
|
||||
rank_,
|
||||
size_,
|
||||
(T*)buffer,
|
||||
size_,
|
||||
-1);
|
||||
recvd = pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barrier),
|
||||
recv_sockets_[0],
|
||||
rank_,
|
||||
size_,
|
||||
(T*)buffer,
|
||||
size_,
|
||||
-1);
|
||||
sent.wait();
|
||||
recvd.wait();
|
||||
std::memcpy(output.data<char>(), buffer, output.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
// If not inplace all reduce then copy the input to the output first
|
||||
if (input.data<void>() != output.data<void>()) {
|
||||
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
|
||||
}
|
||||
|
||||
// All reduce in place. We have `send_channels_.size()` bidirectional
|
||||
// channels so let's split the message up and perform as many parallel
|
||||
// ring-reductions as possible.
|
||||
std::vector<std::future<void>> reductions;
|
||||
std::vector<std::unique_ptr<Barrier>> barriers;
|
||||
size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);
|
||||
|
||||
// Large all reduce territory so let's use all we got
|
||||
if (packets >= 2 * send_sockets_.size()) {
|
||||
size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
|
||||
for (int i = 0; i < send_sockets_.size(); i++) {
|
||||
// 1st ring reduce
|
||||
barriers.emplace_back(std::make_unique<Barrier>(2));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
send_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + 2 * i * segment,
|
||||
std::min(output.size() - 2 * i * segment, segment),
|
||||
-1));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
recv_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + 2 * i * segment,
|
||||
std::min(output.size() - 2 * i * segment, segment),
|
||||
-1));
|
||||
|
||||
// 2nd ring reduce
|
||||
barriers.emplace_back(std::make_unique<Barrier>(2));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
recv_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + (2 * i + 1) * segment,
|
||||
std::min(output.size() - (2 * i + 1) * segment, segment),
|
||||
1));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
send_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + (2 * i + 1) * segment,
|
||||
std::min(output.size() - (2 * i + 1) * segment, segment),
|
||||
1));
|
||||
}
|
||||
}
|
||||
|
||||
// At least 2 reductions so we can be from small to medium
|
||||
else if (packets > 1) {
|
||||
size_t segment = ceildiv(output.size(), packets);
|
||||
for (int i = 0; i < send_sockets_.size(); i++) {
|
||||
barriers.emplace_back(std::make_unique<Barrier>(2));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
send_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + i * segment,
|
||||
std::min(output.size() - i * segment, segment),
|
||||
-1));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
recv_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + i * segment,
|
||||
std::min(output.size() - i * segment, segment),
|
||||
-1));
|
||||
}
|
||||
for (int i = 0; i < packets - send_sockets_.size(); i++) {
|
||||
barriers.emplace_back(std::make_unique<Barrier>(2));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
recv_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + (send_sockets_.size() + i) * segment,
|
||||
std::min(
|
||||
output.size() - (send_sockets_.size() + i) * segment, segment),
|
||||
1));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
send_sockets_[i],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>() + (send_sockets_.size() + i) * segment,
|
||||
std::min(
|
||||
output.size() - (send_sockets_.size() + i) * segment, segment),
|
||||
1));
|
||||
}
|
||||
}
|
||||
|
||||
// Small reduction which won't really benefit much from parallelization.
|
||||
// TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
|
||||
// fairly large array.
|
||||
else {
|
||||
barriers.emplace_back(std::make_unique<Barrier>(2));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_send<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
send_sockets_[0],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>(),
|
||||
output.size(),
|
||||
-1));
|
||||
reductions.push_back(pool_.enqueue(
|
||||
ring_recv_sum<T>,
|
||||
std::reference_wrapper(*barriers.back()),
|
||||
recv_sockets_[0],
|
||||
rank_,
|
||||
size_,
|
||||
output.data<T>(),
|
||||
output.size(),
|
||||
-1));
|
||||
}
|
||||
|
||||
// Wait for the reductions to finish.
|
||||
for (auto& f : reductions) {
|
||||
f.wait();
|
||||
}
|
||||
}
|
||||
|
||||
int rank_;
|
||||
int size_;
|
||||
|
||||
bool verbose_;
|
||||
|
||||
ThreadPool pool_;
|
||||
|
||||
std::vector<int> send_sockets_;
|
||||
std::vector<int> recv_sockets_;
|
||||
};
|
||||
|
||||
// The real ring backend is always available once compiled in; actual
// connectivity is only checked when init() builds the group.
bool is_available() {
  return true;
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
const char* hostfile = std::getenv("MLX_HOSTFILE");
|
||||
const char* rank_str = std::getenv("MLX_RANK");
|
||||
const char* ring_verbose = std::getenv("MLX_RING_VERBOSE");
|
||||
|
||||
if (!hostfile || !rank_str) {
|
||||
if (strict) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) "
|
||||
<< "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\""
|
||||
<< ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\""
|
||||
<< ((hostfile) ? hostfile : "") << "\"";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto nodes = load_nodes(hostfile);
|
||||
int rank = std::atoi(rank_str);
|
||||
|
||||
return std::make_shared<RingGroup>(rank, nodes, ring_verbose != nullptr);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::ring
|
12
mlx/distributed/ring/ring.h
Normal file
12
mlx/distributed/ring/ring.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.

// Fix: the original header had no include guard, making a double inclusion a
// redefinition error.
#pragma once

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// True when the real ring backend (ring.cpp) is compiled in; the no_ring.cpp
// stub returns false.
bool is_available();

// Build the ring group from the MLX_HOSTFILE / MLX_RANK environment
// variables. Returns nullptr when they are not set, or throws if `strict`.
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::ring
|
@ -13,7 +13,7 @@
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "mlx/io/threadpool.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
// Strictly we need to operate on files in binary mode (to avoid \r getting
|
||||
// automatically inserted), but every modern system except for Windows no
|
||||
|
16
mlx/ops.cpp
16
mlx/ops.cpp
@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs(
|
||||
const array& a,
|
||||
const array& start,
|
||||
std::vector<int>& axes,
|
||||
const std::string prefix) {
|
||||
std::string_view prefix) {
|
||||
if (start.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Invalid number of starting positions for "
|
||||
@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs(
|
||||
}
|
||||
std::set dims(axes.begin(), axes.end());
|
||||
if (dims.size() != axes.size()) {
|
||||
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Repeat axes not allowed.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> meshgrid(
|
||||
const std::vector<array>& arrays,
|
||||
bool sparse /* = false */,
|
||||
std::string indexing /* = "xy" */,
|
||||
const std::string& indexing /* = "xy" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (indexing != "xy" && indexing != "ij") {
|
||||
throw std::invalid_argument(
|
||||
@ -1186,7 +1188,7 @@ array pad(
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.size() != low_pad_size.size() ||
|
||||
axes.size() != high_pad_size.size()) {
|
||||
@ -1238,7 +1240,7 @@ array pad(
|
||||
const array& a,
|
||||
const std::vector<std::pair<int, int>>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
std::vector<int> axes(a.ndim(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@ -1258,7 +1260,7 @@ array pad(
|
||||
const array& a,
|
||||
const std::pair<int, int>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a,
|
||||
@ -1272,7 +1274,7 @@ array pad(
|
||||
const array& a,
|
||||
int pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a,
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -222,7 +222,7 @@ split(const array& a, const Shape& indices, StreamOrDevice s = {});
|
||||
std::vector<array> meshgrid(
|
||||
const std::vector<array>& arrays,
|
||||
bool sparse = false,
|
||||
std::string indexing = "xy",
|
||||
const std::string& indexing = "xy",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
@ -274,7 +274,7 @@ array pad(
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
const std::string& mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Pad an array with a constant value along all axes */
|
||||
@ -282,19 +282,19 @@ array pad(
|
||||
const array& a,
|
||||
const std::vector<std::pair<int, int>>& pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
const std::string& mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
array pad(
|
||||
const array& a,
|
||||
const std::pair<int, int>& pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
const std::string& mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
array pad(
|
||||
const array& a,
|
||||
int pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
const std::string& mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Permutes the dimensions in reverse order. */
|
||||
|
@ -38,9 +38,13 @@ class ThreadPool {
|
||||
template <class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::invoke_result_t<F, Args...>>;
|
||||
void resize(size_t);
|
||||
~ThreadPool();
|
||||
|
||||
private:
|
||||
void stop_and_wait();
|
||||
void start_threads(size_t);
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::queue<std::function<void()>> tasks;
|
||||
std::mutex queue_mutex;
|
||||
@ -49,24 +53,7 @@ class ThreadPool {
|
||||
};
|
||||
|
||||
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
|
||||
for (size_t i = 0; i < threads; ++i)
|
||||
workers.emplace_back([this] {
|
||||
for (;;) {
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(
|
||||
lock, [this] { return this->stop || !this->tasks.empty(); });
|
||||
if (this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
});
|
||||
start_threads(threads);
|
||||
}
|
||||
|
||||
template <class F, class... Args>
|
||||
@ -92,12 +79,55 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
return res;
|
||||
}
|
||||
|
||||
inline void ThreadPool::resize(size_t threads) {
|
||||
if (workers.size() == threads) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (workers.size() > threads) {
|
||||
stop_and_wait();
|
||||
}
|
||||
start_threads(threads - workers.size());
|
||||
}
|
||||
|
||||
inline ThreadPool::~ThreadPool() {
|
||||
stop_and_wait();
|
||||
}
|
||||
|
||||
inline void ThreadPool::stop_and_wait() {
|
||||
// Stop the current threads and wait until they finish
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for (std::thread& worker : workers)
|
||||
for (std::thread& worker : workers) {
|
||||
worker.join();
|
||||
}
|
||||
|
||||
// Reset the member variables so that the threadpool is reusable
|
||||
stop = false;
|
||||
workers.clear();
|
||||
}
|
||||
|
||||
inline void ThreadPool::start_threads(size_t threads) {
|
||||
for (size_t i = 0; i < threads; ++i) {
|
||||
workers.emplace_back([this] {
|
||||
for (;;) {
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(
|
||||
lock, [this] { return this->stop || !this->tasks.empty(); });
|
||||
if (this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) {
|
||||
"init",
|
||||
&mx::distributed::init,
|
||||
"strict"_a = false,
|
||||
nb::sig("def init(strict: bool = False) -> Group"),
|
||||
"backend"_a = "any",
|
||||
nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
|
||||
R"pbdoc(
|
||||
Initialize the communication backend and create the global communication group.
|
||||
|
||||
Example:
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
group = mx.distributed.init(backend="ring")
|
||||
|
||||
|
||||
Args:
|
||||
strict (bool, optional): If set to False it returns a singleton group
|
||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||
it throws a runtime error. Default: ``False``
|
||||
backend (str, optional): Select a specific distributed backend to
|
||||
initialize. If set to ``any`` then try all available backends and
|
||||
return the first one that succeeds. Subsequent calls will return
|
||||
the first backend that was initialized. Default: ``any``
|
||||
|
||||
Returns:
|
||||
Group: The group representing all the launched processes.
|
||||
|
@ -34,6 +34,8 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.float16,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
for dt in dtypes:
|
||||
|
61
python/tests/ring_test_distributed.py
Normal file
61
python/tests/ring_test_distributed.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestRingDistributed(mlx_tests.MLXTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="ring")
|
||||
|
||||
def test_groups(self):
|
||||
world = mx.distributed.init()
|
||||
self.assertEqual(world.size(), 8)
|
||||
self.assertTrue(0 <= world.rank() < 8)
|
||||
|
||||
world2 = mx.distributed.init()
|
||||
self.assertEqual(world.size(), world2.size())
|
||||
self.assertEqual(world.rank(), world2.rank())
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
sub = world.split(world.rank() % 2)
|
||||
|
||||
def test_all_reduce(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int16, 0),
|
||||
(mx.uint16, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
(mx.complex64, 1e-6),
|
||||
]
|
||||
sizes = [
|
||||
(7,),
|
||||
(10,),
|
||||
(1024,),
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
y = mx.distributed.all_sum(x[world.rank()])
|
||||
z = sum(
|
||||
x[i] for i in range(world.size())
|
||||
) # to ensure that we don't sum to int32
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
25
python/tests/run_ring_test.sh
Normal file
25
python/tests/run_ring_test.sh
Normal file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
tmpfile=$(mktemp)
|
||||
cat <<HOSTFILE >$tmpfile
|
||||
[
|
||||
["127.0.0.1:5000"],
|
||||
["127.0.0.1:5001"],
|
||||
["127.0.0.1:5002"],
|
||||
["127.0.0.1:5003"],
|
||||
["127.0.0.1:5004"],
|
||||
["127.0.0.1:5005"],
|
||||
["127.0.0.1:5006"],
|
||||
["127.0.0.1:5007"]
|
||||
]
|
||||
HOSTFILE
|
||||
|
||||
ring_test="$(dirname ${BASH_SOURCE[0]})/ring_test_distributed.py"
|
||||
|
||||
for i in {0..7}; do
|
||||
if (($i == 7)); then
|
||||
sleep 1
|
||||
fi
|
||||
DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE=$tmpfile MLX_RANK=$i python $ring_test &
|
||||
done
|
||||
wait
|
Loading…
Reference in New Issue
Block a user