Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)

Commit: Ring distributed backend (#1784)
Parent: 2235dee906
This commit: ccb61d7aae
@@ -160,6 +160,7 @@ jobs:
           LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
           LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
           mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+          /bin/bash python/tests/run_ring_test.sh
       - run:
           name: Build example extension
           command: |
@@ -5,3 +5,4 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
@@ -1,8 +1,11 @@
 // Copyright © 2024 Apple Inc.
 
+#include <unordered_map>
+
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
+#include "mlx/distributed/ring/ring.h"
 #include "mlx/scheduler.h"
 
 namespace mlx::core::distributed {
@@ -65,7 +68,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail
 
 bool is_available() {
-  return mpi::is_available();
+  return mpi::is_available() || ring::is_available();
 }
 
 int Group::rank() const {
@@ -80,20 +83,50 @@ Group Group::split(int color, int key /* = -1 */) const {
   return Group(group_->split(color, key));
 }
 
-Group init(bool strict /* = false */) {
-  auto init_group = [strict]() {
-    auto default_group = mpi::init(strict);
-    if (default_group == nullptr) {
-      default_group = std::make_shared<detail::EmptyGroup>();
-    }
-    return default_group;
-  };
-  static std::shared_ptr<detail::GroupImpl> default_group = init_group();
+Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
+  static std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>
+      backends;
+
+  // Already initialized so return the group.
+  if (auto g = backends.find(bk); g != backends.end()) {
+    return Group(g->second);
+  }
+
+  // Create the requested communication group
+  std::shared_ptr<detail::GroupImpl> group;
+  std::string bk_ = bk;
+  if (bk == "mpi") {
+    group = mpi::init(strict);
+  } else if (bk == "ring") {
+    group = ring::init(strict);
+  } else if (bk == "any") {
+    group = ring::init(false);
+    bk_ = "ring";
+    if (group == nullptr) {
+      group = mpi::init(false);
+      bk_ = "mpi";
+    }
+    if (group == nullptr && strict) {
+      throw std::runtime_error("[distributed] Couldn't initialize any backend");
+    }
+  } else {
+    std::ostringstream msg;
+    msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
+        << "and 'ring' but '" << bk << "' was provided.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (group == nullptr) {
+    group = std::make_shared<detail::EmptyGroup>();
+  } else {
+    backends.insert({"any", group});
+  }
+  backends.insert({std::move(bk_), group});
 
   // Ensure the communication stream is alive before
   // the graph is evaluated
   detail::communication_stream();
-  return Group(default_group);
+  return Group(group);
 }
 
 } // namespace mlx::core::distributed
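For orientation, a minimal sketch of the resulting C++ entry point (the init signature comes from the hunk above; the surrounding main() is illustrative only and assumes the library is linked):

#include "mlx/distributed/distributed.h"

int main() {
  namespace dist = mlx::core::distributed;

  // Explicitly request the ring backend; with strict = true this throws if
  // the ring backend can't come up (e.g. MLX_HOSTFILE / MLX_RANK are unset).
  auto ring_group = dist::init(/* strict = */ true, "ring");

  // The default backend "any" tries ring first, then MPI, and caches the
  // group that succeeded, so repeated calls return the same group.
  auto group = dist::init();
  return group.rank() == ring_group.rank() ? 0 : 1;
}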
@@ -53,6 +53,6 @@ struct Group {
  * distributed subsystem. Otherwise simply return a singleton group which will
  * render communication operations as no-op.
  */
-Group init(bool strict = false);
+Group init(bool strict = false, const std::string& bk = "any");
 
 } // namespace mlx::core::distributed
@@ -11,6 +11,8 @@ namespace mlx::core::distributed::detail {
  */
 class GroupImpl {
  public:
+  virtual ~GroupImpl() {}
+
   virtual int rank() = 0;
   virtual int size() = 0;
   virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
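The added virtual destructor is what lets a backend-specific group (such as RingGroup below, which closes sockets in its destructor) be destroyed correctly through a GroupImpl pointer. A tiny standalone illustration, with hypothetical class names, not taken from the commit:

struct Base {
  virtual ~Base() {}  // ensures `delete` through a Base* runs ~Derived()
};
struct Derived : Base {
  ~Derived() { /* release resources: close sockets, join threads, ... */ }
};

int main() {
  Base* p = new Derived();
  delete p;  // well defined and runs ~Derived() because ~Base() is virtual
  return 0;
}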
mlx/distributed/ring/CMakeLists.txt (new file, 5 lines):

if(MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)
endif()
mlx/distributed/ring/no_ring.cpp (new file, 20 lines):

// Copyright © 2024 Apple Inc.

#include "mlx/distributed/ring/ring.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize ring distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::ring
mlx/distributed/ring/ring.cpp (new file, 827 lines):

// Copyright © 2024 Apple Inc.

#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>

#include <chrono>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>

#include <json.hpp>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h"

#define SWITCH_TYPE(x, ...) \
  switch ((x).dtype()) { \
    case bool_: { \
      using T = bool; \
      __VA_ARGS__; \
    } break; \
    case int8: { \
      using T = int8_t; \
      __VA_ARGS__; \
    } break; \
    case int16: { \
      using T = int16_t; \
      __VA_ARGS__; \
    } break; \
    case int32: { \
      using T = int32_t; \
      __VA_ARGS__; \
    } break; \
    case int64: { \
      using T = int64_t; \
      __VA_ARGS__; \
    } break; \
    case uint8: { \
      using T = uint8_t; \
      __VA_ARGS__; \
    } break; \
    case uint16: { \
      using T = uint16_t; \
      __VA_ARGS__; \
    } break; \
    case uint32: { \
      using T = uint32_t; \
      __VA_ARGS__; \
    } break; \
    case uint64: { \
      using T = uint64_t; \
      __VA_ARGS__; \
    } break; \
    case bfloat16: { \
      using T = bfloat16_t; \
      __VA_ARGS__; \
    } break; \
    case float16: { \
      using T = float16_t; \
      __VA_ARGS__; \
    } break; \
    case float32: { \
      using T = float; \
      __VA_ARGS__; \
    } break; \
    case complex64: { \
      using T = complex64_t; \
      __VA_ARGS__; \
    } break; \
  }

namespace mlx::core::distributed::ring {

constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;

using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;

namespace {

class Barrier {
 public:
  explicit Barrier(int n_threads)
      : n_threads_(n_threads), count_(0), flag_(false) {}

  void arrive_and_wait() {
    std::unique_lock<std::mutex> lock(mtx_);

    // Keep the flag that marks the current use of the barrier. The next use is
    // going to have this flag flipped.
    bool initial_flag = flag_;

    // Increment the count
    count_++;

    // We are the last thread to arrive so reset the count, change the flag and
    // notify everybody.
    if (count_ == n_threads_) {
      count_ = 0;
      flag_ = !flag_;
      cv_.notify_all();
    }

    // Wait for the rest to arrive
    else {
      cv_.wait(lock, [this, initial_flag]() { return initial_flag != flag_; });
    }
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  int n_threads_;

  int count_;
  bool flag_; // we need this for sequential use of the barrier
};
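The Barrier above is a small reusable two-phase barrier: the flag flip is what makes the same object safe to reuse across iterations. A minimal usage sketch (illustration only, assuming the Barrier class above is in scope and <thread> is included, as it is in ring.cpp):

void barrier_demo() {
  Barrier barrier(2);
  auto worker = [&barrier]() {
    for (int step = 0; step < 3; step++) {
      // ... do one chunk of work, e.g. send or receive one segment ...
      barrier.arrive_and_wait(); // nobody starts step+1 before both finish step
    }
  };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
}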
template <typename T>
void log(std::ostream& os, T first) {
  os << first << std::endl;
}

template <typename T, typename... Args>
void log(std::ostream& os, T first, Args... args) {
  log(os << first << " ", args...);
}

template <typename... Args>
void log_info(bool verbose, Args... args) {
  if (!verbose) {
    return;
  }

  log(std::cerr, "[ring]", args...);
}

template <typename T, typename U>
decltype(T() * U()) ceildiv(T a, U b) {
  return (a + b - 1) / b;
}

struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  const sockaddr* get() const {
    return (struct sockaddr*)&addr;
  }
};

/**
 * Parse a sockaddr from an ip and port provided as strings.
 */
address_t parse_address(const std::string& ip, const std::string& port) {
  struct addrinfo hints, *res;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
  if (status != 0) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip << ":" << port;
    throw std::runtime_error(msg.str());
  }

  address_t result;
  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
  result.len = res->ai_addrlen;
  freeaddrinfo(res);

  return result;
}

/**
 * Parse a sockaddr provided as an <ip>:<port> string.
 */
address_t parse_address(const std::string& ip_port) {
  auto colon = ip_port.find(":");
  if (colon == std::string::npos) {
    std::ostringstream msg;
    msg << "Can't parse address " << ip_port;
    throw std::runtime_error(msg.str());
  }
  std::string ip(ip_port.begin(), ip_port.begin() + colon);
  std::string port(ip_port.begin() + colon + 1, ip_port.end());

  return parse_address(ip, port);
}

/**
 * Load all addresses from the json hostfile. The hostfile is a list of
 * addresses in order of rank. For each rank there can be many addresses so
 * that we can have multiple connections between peers.
 *
 * For example:
 * [
 *   ["ip1:5000", "ip1:5001"],
 *   ["ip2:5000", "ip2:5001"],
 *   ["ip3:5000", "ip3:5001"],
 * ]
 */
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
  std::vector<std::vector<address_t>> nodes;
  std::ifstream f(hostfile);

  json hosts = json::parse(f);
  for (auto& h : hosts) {
    std::vector<address_t> host;
    for (auto& ips : h) {
      host.push_back(std::move(parse_address(ips.get<std::string>())));
    }
    nodes.push_back(std::move(host));
  }

  return nodes;
}

/**
 * Create a socket and accept one connection for each of the provided
 * addresses.
 */
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
  std::vector<int> sockets;
  int success;

  for (auto& address : addresses) {
    // Create the socket to wait for connections from the peers
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
      std::ostringstream msg;
      msg << "[ring] Couldn't create socket (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    // Make sure we can launch immediately after shutdown by setting the
    // reuseaddr option so that we don't get address already in use errors
    int enable = 1;
    success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
    if (success < 0) {
      shutdown(sock, 2);
      close(sock);
      std::ostringstream msg;
      msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }
    success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
    if (success < 0) {
      shutdown(sock, 2);
      close(sock);
      std::ostringstream msg;
      msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    // Bind the socket to the address and port
    success = bind(sock, address.get(), address.len);
    if (success < 0) {
      shutdown(sock, 2);
      close(sock);
      std::ostringstream msg;
      msg << "[ring] Couldn't bind socket (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    // Wait for connections
    success = listen(sock, 0);
    if (success < 0) {
      shutdown(sock, 2);
      close(sock);
      std::ostringstream msg;
      msg << "[ring] Couldn't listen (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    int peer_socket = accept(sock, nullptr, nullptr);
    if (peer_socket < 0) {
      shutdown(sock, 2);
      close(sock);
      std::ostringstream msg;
      msg << "[ring] Accept failed (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    // Close the listening socket
    shutdown(sock, 2);
    close(sock);

    sockets.push_back(peer_socket);
  }

  return sockets;
}

/**
 * The counterpoint of `accept_connections`. Basically connect to each of the
 * provided addresses.
 */
std::vector<int> make_connections(
    const std::vector<address_t>& addresses,
    bool verbose) {
  std::vector<int> sockets;
  int success;

  for (auto& address : addresses) {
    int sock;

    // Attempt to connect to the peer CONN_ATTEMPTS times with exponential
    // backoff. TODO: Do we need that?
    for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
      // Create the socket
      sock = socket(AF_INET, SOCK_STREAM, 0);
      if (sock < 0) {
        std::ostringstream msg;
        msg << "[ring] Couldn't create socket (error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }

      if (attempt > 0) {
        int wait = (1 << (attempt - 1)) * CONN_WAIT;
        log_info(
            verbose,
            "Attempt",
            attempt,
            "wait",
            wait,
            "ms (error:",
            errno,
            ")");
        std::this_thread::sleep_for(std::chrono::milliseconds(wait));
      }

      success = connect(sock, address.get(), address.len);
      if (success == 0) {
        break;
      }
    }
    if (success < 0) {
      std::ostringstream msg;
      msg << "[ring] Couldn't connect (error: " << errno << ")";
      throw std::runtime_error(msg.str());
    }

    sockets.push_back(sock);
  }

  return sockets;
}

array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}

template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
  while (N-- > 0) {
    *output += *input;
    input++;
    output++;
  }
}

template <typename T>
void _send(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }
  data += start;
  size_t len = (stop - start) * sizeof(T);
  const char* buffer = (const char*)data;
  while (len > 0) {
    ssize_t r = send(sock, buffer, len, 0);
    if (r <= 0) {
      std::ostringstream msg;
      msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += r;
    len -= r;
  }
}

template <typename T>
void _recv(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }
  data += start;
  size_t len = (stop - start) * sizeof(T);
  char* buffer = (char*)data;
  while (len > 0) {
    ssize_t r = recv(sock, buffer, len, 0);
    if (r <= 0) {
      std::ostringstream msg;
      msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
      throw std::runtime_error(msg.str());
    }
    buffer += r;
    len -= r;
  }
}

template <typename T>
void _recv_sum(int sock, T* data, size_t start, size_t stop) {
  if (stop <= start) {
    return;
  }
  data += start;
  char buffer[PACKET_SIZE];
  size_t len = (stop - start) * sizeof(T);
  while (len > 0) {
    ssize_t r = 0;
    do {
      ssize_t partial_r =
          recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
      if (partial_r <= 0) {
        std::ostringstream msg;
        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      r += partial_r;
    } while (r % sizeof(T));
    sum_inplace((const T*)buffer, data, r / sizeof(T));
    data += r / sizeof(T);
    len -= r;
  }
}

template <typename T>
void ring_send(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment
  int segment = rank;

  // 1st send
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }

  // 2nd send
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _send<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}

template <typename T>
void ring_recv_sum(
    Barrier& barrier,
    int socket,
    int rank,
    int size,
    T* data,
    size_t data_size,
    int direction = -1) {
  // We split the data into `size_` segments of size `segment_size`
  size_t segment_size = ceildiv(data_size, size);

  // Initial segment
  int segment = (rank + size + direction) % size;

  // Recv sum
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv_sum<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }

  // Recv
  for (int i = 0; i < size - 1; i++) {
    size_t start = segment * segment_size;
    size_t stop = std::min((segment + 1) * segment_size, data_size);
    _recv<T>(socket, data, start, stop);
    barrier.arrive_and_wait();
    segment = (segment + size + direction) % size;
  }
}

} // namespace
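Together, ring_send and ring_recv_sum implement a classic ring all-reduce: a reduce-scatter pass (each rank forwards one segment while accumulating the incoming one) followed by an all-gather pass that circulates the fully reduced segments. A standalone sketch, not part of the commit, that just prints the segment schedule produced by the index arithmetic above (numbers chosen for illustration):

#include <cstdio>

int main() {
  int size = 4, direction = -1;
  for (int rank = 0; rank < size; rank++) {
    int segment = rank; // ring_send starts from the rank's own segment
    std::printf("rank %d sends segments:", rank);
    for (int i = 0; i < size - 1; i++) {
      std::printf(" %d", segment);
      segment = (segment + size + direction) % size;
    }
    std::printf("\n");
  }
  return 0;
}

After the first size - 1 steps each rank owns one completely summed segment; the second loop then passes those reduced segments around so every rank ends with the full result.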
class RingGroup : public GroupImpl {
 public:
  RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
      : rank_(rank), verbose_(verbose), pool_(0) {
    if (rank_ > 0 && rank_ >= nodes.size()) {
      throw std::runtime_error(
          "[ring] Rank cannot be larger than the size of the group");
    }

    size_ = nodes.size();
    int connect_to = (rank_ + 1) % size_;

    // We define the connection order by having the rank_ == size_ - 1 connect
    // first and accept after.
    if (rank_ < connect_to) {
      log_info(verbose_, "Rank", rank_, "accepting");
      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
    } else {
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
      log_info(verbose_, "Rank", rank_, "accepting");
      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
    }

    // Failure if we couldn't make send or recv sockets
    if (send_sockets_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no send sockets.";
      throw std::invalid_argument(msg.str());
    }
    if (recv_sockets_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no recv sockets.";
      throw std::invalid_argument(msg.str());
    }

    // The following could be relaxed since we can define non-homogeneous rings
    // but it makes things a bit simpler for now.
    if (send_sockets_.size() != recv_sockets_.size()) {
      std::ostringstream msg;
      msg << "[ring] It is required to have as many connections to the left as "
          << "to the right but rank " << rank_ << " has "
          << send_sockets_.size() << " connections to the right and "
          << recv_sockets_.size() << " to the left.";
      throw std::invalid_argument(msg.str());
    }

    // Start the necessary threads for completely parallel operation on all
    // channels. One thread to send, one to receive per socket.
    pool_.resize(send_sockets_.size() * 2 * 2);
  }

  ~RingGroup() {
    for (auto s : send_sockets_) {
      shutdown(s, 2);
      close(s);
    }
    for (auto s : recv_sockets_) {
      shutdown(s, 2);
      close(s);
    }
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  void all_sum(const array& input_, array& output) override {
    SWITCH_TYPE(output, all_sum<T>(input_, output));
  }

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[ring] Group split not supported.");
  }
  void all_gather(const array& input, array& output) override {
    throw std::runtime_error("[ring] All gather not supported.");
  }
  void send(const array& input, int dst) override {
    throw std::runtime_error("[ring] Send not supported.");
  }
  void recv(array& out, int src) override {
    throw std::runtime_error("[ring] Recv not supported.");
  }

 private:
  template <typename T>
  void all_sum(const array& input_, array& output) {
    // Make sure that the input is row contiguous
    array input = ensure_row_contiguous(input_);

    // If the input data cannot be split into size_ segments then copy it and
    // all reduce a local buffer prefilled with 0s.
    if (input.size() < size_) {
      // TODO: Maybe allocate dynamically so we don't have the constraint below?
      if (input.itemsize() * size_ > 1024) {
        std::ostringstream msg;
        msg << "Can't perform the ring all reduce of " << output.size()
            << " elements with a ring of size " << size_;
        throw std::runtime_error(msg.str());
      }

      std::future<void> sent, recvd;
      auto barrier = std::make_unique<Barrier>(2);
      char buffer[1024];
      std::memset(buffer, 0, size_ * input.itemsize());
      std::memcpy(buffer, input.data<char>(), input.nbytes());
      sent = pool_.enqueue(
          ring_send<T>,
          std::reference_wrapper(*barrier),
          send_sockets_[0],
          rank_,
          size_,
          (T*)buffer,
          size_,
          -1);
      recvd = pool_.enqueue(
          ring_recv_sum<T>,
          std::reference_wrapper(*barrier),
          recv_sockets_[0],
          rank_,
          size_,
          (T*)buffer,
          size_,
          -1);
      sent.wait();
      recvd.wait();
      std::memcpy(output.data<char>(), buffer, output.nbytes());
      return;
    }

    // If not inplace all reduce then copy the input to the output first
    if (input.data<void>() != output.data<void>()) {
      std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
    }

    // All reduce in place. We have `send_channels_.size()` bidirectional
    // channels so let's split the message up and perform as many parallel
    // ring-reductions as possible.
    std::vector<std::future<void>> reductions;
    std::vector<std::unique_ptr<Barrier>> barriers;
    size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);

    // Large all reduce territory so let's use all we got
    if (packets >= 2 * send_sockets_.size()) {
      size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
      for (int i = 0; i < send_sockets_.size(); i++) {
        // 1st ring reduce
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + 2 * i * segment,
            std::min(output.size() - 2 * i * segment, segment),
            -1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + 2 * i * segment,
            std::min(output.size() - 2 * i * segment, segment),
            -1));

        // 2nd ring reduce
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (2 * i + 1) * segment,
            std::min(output.size() - (2 * i + 1) * segment, segment),
            1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (2 * i + 1) * segment,
            std::min(output.size() - (2 * i + 1) * segment, segment),
            1));
      }
    }

    // At least 2 reductions so we can be from small to medium
    else if (packets > 1) {
      size_t segment = ceildiv(output.size(), packets);
      for (int i = 0; i < send_sockets_.size(); i++) {
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + i * segment,
            std::min(output.size() - i * segment, segment),
            -1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + i * segment,
            std::min(output.size() - i * segment, segment),
            -1));
      }
      for (int i = 0; i < packets - send_sockets_.size(); i++) {
        barriers.emplace_back(std::make_unique<Barrier>(2));
        reductions.push_back(pool_.enqueue(
            ring_send<T>,
            std::reference_wrapper(*barriers.back()),
            recv_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (send_sockets_.size() + i) * segment,
            std::min(
                output.size() - (send_sockets_.size() + i) * segment, segment),
            1));
        reductions.push_back(pool_.enqueue(
            ring_recv_sum<T>,
            std::reference_wrapper(*barriers.back()),
            send_sockets_[i],
            rank_,
            size_,
            output.data<T>() + (send_sockets_.size() + i) * segment,
            std::min(
                output.size() - (send_sockets_.size() + i) * segment, segment),
            1));
      }
    }

    // Small reduction which won't really benefit much from parallelization.
    // TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
    // fairly large array.
    else {
      barriers.emplace_back(std::make_unique<Barrier>(2));
      reductions.push_back(pool_.enqueue(
          ring_send<T>,
          std::reference_wrapper(*barriers.back()),
          send_sockets_[0],
          rank_,
          size_,
          output.data<T>(),
          output.size(),
          -1));
      reductions.push_back(pool_.enqueue(
          ring_recv_sum<T>,
          std::reference_wrapper(*barriers.back()),
          recv_sockets_[0],
          rank_,
          size_,
          output.data<T>(),
          output.size(),
          -1));
    }

    // Wait for the reductions to finish.
    for (auto& f : reductions) {
      f.wait();
    }
  }

  int rank_;
  int size_;

  bool verbose_;

  ThreadPool pool_;

  std::vector<int> send_sockets_;
  std::vector<int> recv_sockets_;
};

bool is_available() {
  return true;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  const char* hostfile = std::getenv("MLX_HOSTFILE");
  const char* rank_str = std::getenv("MLX_RANK");
  const char* ring_verbose = std::getenv("MLX_RING_VERBOSE");

  if (!hostfile || !rank_str) {
    if (strict) {
      std::ostringstream msg;
      msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) "
          << "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\""
          << ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\""
          << ((hostfile) ? hostfile : "") << "\"";
      throw std::runtime_error(msg.str());
    }
    return nullptr;
  }

  auto nodes = load_nodes(hostfile);
  int rank = std::atoi(rank_str);

  return std::make_shared<RingGroup>(rank, nodes, ring_verbose != nullptr);
}

} // namespace mlx::core::distributed::ring
mlx/distributed/ring/ring.h (new file, 12 lines):

// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::ring
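The ring backend is configured entirely through environment variables, as read in ring::init above: MLX_HOSTFILE (path to the JSON hostfile), MLX_RANK, and optionally MLX_RING_VERBOSE. A minimal sketch of standing up one rank by hand (path and port illustrative only; normally a launcher such as run_ring_test.sh starts one process per rank, and init blocks until all peers connect):

#include <cstdlib>
#include "mlx/distributed/distributed.h"

int main() {
  // Illustrative values: the hostfile is a JSON list of per-rank address lists,
  // e.g. [["127.0.0.1:5000"], ["127.0.0.1:5001"]].
  setenv("MLX_HOSTFILE", "/tmp/hosts.json", 1);
  setenv("MLX_RANK", "0", 1);
  setenv("MLX_RING_VERBOSE", "1", 1); // log connection attempts to stderr

  // strict = true: throw instead of silently falling back to a singleton group.
  auto group = mlx::core::distributed::init(true, "ring");
  return group.size();
}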
@@ -13,7 +13,7 @@
 #include <unistd.h>
 #endif
 
-#include "mlx/io/threadpool.h"
+#include "mlx/threadpool.h"
 
 // Strictly we need to operate on files in binary mode (to avoid \r getting
 // automatically inserted), but every modern system except for Windows no
mlx/ops.cpp:

@@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs(
     const array& a,
     const array& start,
     std::vector<int>& axes,
-    const std::string prefix) {
+    std::string_view prefix) {
   if (start.size() > a.ndim()) {
     std::ostringstream msg;
     msg << prefix << " Invalid number of starting positions for "
@@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs(
   }
   std::set dims(axes.begin(), axes.end());
   if (dims.size() != axes.size()) {
-    throw std::invalid_argument(prefix + " Repeat axes not allowed.");
+    std::ostringstream msg;
+    msg << prefix << " Repeat axes not allowed.";
+    throw std::invalid_argument(msg.str());
   }
 }
 
@@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
 std::vector<array> meshgrid(
     const std::vector<array>& arrays,
     bool sparse /* = false */,
-    std::string indexing /* = "xy" */,
+    const std::string& indexing /* = "xy" */,
     StreamOrDevice s /* = {} */) {
   if (indexing != "xy" && indexing != "ij") {
     throw std::invalid_argument(
@@ -1186,7 +1188,7 @@ array pad(
     const Shape& low_pad_size,
     const Shape& high_pad_size,
     const array& pad_value /*= array(0)*/,
-    const std::string mode /*= "constant"*/,
+    const std::string& mode /*= "constant"*/,
     StreamOrDevice s /* = {}*/) {
   if (axes.size() != low_pad_size.size() ||
       axes.size() != high_pad_size.size()) {
@@ -1238,7 +1240,7 @@ array pad(
     const array& a,
     const std::vector<std::pair<int, int>>& pad_width,
     const array& pad_value /*= array(0)*/,
-    const std::string mode /*= "constant"*/,
+    const std::string& mode /*= "constant"*/,
     StreamOrDevice s /*= {}*/) {
   std::vector<int> axes(a.ndim(), 0);
   std::iota(axes.begin(), axes.end(), 0);
@@ -1258,7 +1260,7 @@ array pad(
     const array& a,
     const std::pair<int, int>& pad_width,
     const array& pad_value /*= array(0)*/,
-    const std::string mode /*= "constant"*/,
+    const std::string& mode /*= "constant"*/,
     StreamOrDevice s /*= {}*/) {
   return pad(
       a,
@@ -1272,7 +1274,7 @@ array pad(
     const array& a,
     int pad_width,
     const array& pad_value /*= array(0)*/,
-    const std::string mode /*= "constant"*/,
+    const std::string& mode /*= "constant"*/,
     StreamOrDevice s /*= {}*/) {
   return pad(
       a,
mlx/ops.h:

@@ -222,7 +222,7 @@ split(const array& a, const Shape& indices, StreamOrDevice s = {});
 std::vector<array> meshgrid(
     const std::vector<array>& arrays,
     bool sparse = false,
-    std::string indexing = "xy",
+    const std::string& indexing = "xy",
     StreamOrDevice s = {});
 
 /**
@@ -274,7 +274,7 @@ array pad(
     const Shape& low_pad_size,
     const Shape& high_pad_size,
     const array& pad_value = array(0),
-    const std::string mode = "constant",
+    const std::string& mode = "constant",
     StreamOrDevice s = {});
 
 /** Pad an array with a constant value along all axes */
@@ -282,19 +282,19 @@ array pad(
     const array& a,
     const std::vector<std::pair<int, int>>& pad_width,
     const array& pad_value = array(0),
-    const std::string mode = "constant",
+    const std::string& mode = "constant",
     StreamOrDevice s = {});
 array pad(
     const array& a,
     const std::pair<int, int>& pad_width,
     const array& pad_value = array(0),
-    const std::string mode = "constant",
+    const std::string& mode = "constant",
     StreamOrDevice s = {});
 array pad(
     const array& a,
     int pad_width,
     const array& pad_value = array(0),
-    const std::string mode = "constant",
+    const std::string& mode = "constant",
     StreamOrDevice s = {});
 
 /** Permutes the dimensions in reverse order. */
@@ -38,9 +38,13 @@ class ThreadPool {
   template <class F, class... Args>
   auto enqueue(F&& f, Args&&... args)
       -> std::future<typename std::invoke_result_t<F, Args...>>;
+  void resize(size_t);
   ~ThreadPool();
 
  private:
+  void stop_and_wait();
+  void start_threads(size_t);
+
   std::vector<std::thread> workers;
   std::queue<std::function<void()>> tasks;
   std::mutex queue_mutex;
@@ -49,24 +53,7 @@ class ThreadPool {
 };
 
 inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
-  for (size_t i = 0; i < threads; ++i)
-    workers.emplace_back([this] {
-      for (;;) {
-        std::function<void()> task;
-
-        {
-          std::unique_lock<std::mutex> lock(this->queue_mutex);
-          this->condition.wait(
-              lock, [this] { return this->stop || !this->tasks.empty(); });
-          if (this->stop && this->tasks.empty())
-            return;
-          task = std::move(this->tasks.front());
-          this->tasks.pop();
-        }
-
-        task();
-      }
-    });
+  start_threads(threads);
 }
 
 template <class F, class... Args>
@@ -92,12 +79,55 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
   return res;
 }
 
+inline void ThreadPool::resize(size_t threads) {
+  if (workers.size() == threads) {
+    return;
+  }
+
+  if (workers.size() > threads) {
+    stop_and_wait();
+  }
+  start_threads(threads - workers.size());
+}
+
 inline ThreadPool::~ThreadPool() {
+  stop_and_wait();
+}
+
+inline void ThreadPool::stop_and_wait() {
+  // Stop the current threads and wait until they finish
   {
     std::unique_lock<std::mutex> lock(queue_mutex);
     stop = true;
   }
   condition.notify_all();
-  for (std::thread& worker : workers)
+  for (std::thread& worker : workers) {
     worker.join();
+  }
+
+  // Reset the member variables so that the threadpool is reusable
+  stop = false;
+  workers.clear();
+}
+
+inline void ThreadPool::start_threads(size_t threads) {
+  for (size_t i = 0; i < threads; ++i) {
+    workers.emplace_back([this] {
+      for (;;) {
+        std::function<void()> task;
+
+        {
+          std::unique_lock<std::mutex> lock(this->queue_mutex);
+          this->condition.wait(
+              lock, [this] { return this->stop || !this->tasks.empty(); });
+          if (this->stop && this->tasks.empty())
+            return;
+          task = std::move(this->tasks.front());
+          this->tasks.pop();
+        }
+
+        task();
+      }
+    });
+  }
 }
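The new resize()/stop_and_wait()/start_threads() methods let RingGroup construct the pool with zero workers and grow it once it knows how many sockets it has. A minimal usage sketch (namespace qualification omitted; the squaring task is illustrative only):

#include <future>
#include <vector>
#include "mlx/threadpool.h"

int main() {
  ThreadPool pool(0); // start with no worker threads
  pool.resize(4);     // grow once the needed parallelism is known

  std::vector<std::future<int>> results;
  for (int i = 0; i < 8; i++) {
    results.push_back(pool.enqueue([](int x) { return x * x; }, i));
  }
  int total = 0;
  for (auto& r : results) {
    total += r.get(); // 0 + 1 + 4 + ... + 49 = 140
  }
  return total == 140 ? 0 : 1;
}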
@@ -3,6 +3,7 @@
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>
 
@@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) {
       "init",
       &mx::distributed::init,
       "strict"_a = false,
-      nb::sig("def init(strict: bool = False) -> Group"),
+      "backend"_a = "any",
+      nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
       R"pbdoc(
         Initialize the communication backend and create the global communication group.
 
+        Example:
+
+          import mlx.core as mx
+
+          group = mx.distributed.init(backend="ring")
+
         Args:
           strict (bool, optional): If set to False it returns a singleton group
             in case ``mx.distributed.is_available()`` returns False otherwise
             it throws a runtime error. Default: ``False``
+          backend (str, optional): Select a specific distributed backend to
+            initialize. If set to ``any`` then try all available backends and
+            return the first one that succeeds. Subsequent calls will return
+            the first backend that was initialized. Default: ``any``
 
         Returns:
           Group: The group representing all the launched processes.
@@ -34,6 +34,8 @@ class TestDistributed(mlx_tests.MLXTestCase):
             mx.int32,
             mx.uint32,
             mx.float32,
+            mx.float16,
+            mx.bfloat16,
             mx.complex64,
         ]
         for dt in dtypes:
python/tests/ring_test_distributed.py (new file, 61 lines):

# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


class TestRingDistributed(mlx_tests.MLXTestCase):
    @classmethod
    def setUpClass(cls):
        world = mx.distributed.init(strict=True, backend="ring")

    def test_groups(self):
        world = mx.distributed.init()
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        with self.assertRaises(RuntimeError):
            sub = world.split(world.rank() % 2)

    def test_all_reduce(self):
        world = mx.distributed.init()
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
            (mx.complex64, 1e-6),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)
        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)
                y = mx.distributed.all_sum(x[world.rank()])
                z = sum(
                    x[i] for i in range(world.size())
                )  # to ensure that we don't sum to int32
                maxrelerror = ((y - z).abs() / z.abs()).max()
                self.assertLessEqual(maxrelerror, rtol)


if __name__ == "__main__":
    unittest.main()
python/tests/run_ring_test.sh (new file, 25 lines):

#!/bin/bash

tmpfile=$(mktemp)
cat <<HOSTFILE >$tmpfile
[
["127.0.0.1:5000"],
["127.0.0.1:5001"],
["127.0.0.1:5002"],
["127.0.0.1:5003"],
["127.0.0.1:5004"],
["127.0.0.1:5005"],
["127.0.0.1:5006"],
["127.0.0.1:5007"]
]
HOSTFILE

ring_test="$(dirname ${BASH_SOURCE[0]})/ring_test_distributed.py"

for i in {0..7}; do
  if (($i == 7)); then
    sleep 1
  fi
  DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE=$tmpfile MLX_RANK=$i python $ring_test &
done
wait