mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
9 Commits
ibv-backen
...
7a82455b35
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a82455b35 | ||
|
|
643a9a6ba6 | ||
|
|
82097a8f85 | ||
|
|
29d9cd836a | ||
|
|
2d10020178 | ||
|
|
031e62539a | ||
|
|
97f74543b1 | ||
|
|
0dbe63397d | ||
|
|
873df2e0f7 |
@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
|
|||||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
execute_process(
|
||||||
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
|
||||||
|
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||||
message(
|
message(
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||||
|
|
||||||
|
if(MLX_BUILD_CPU AND NOT WIN32)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/ibv/ibv.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
@@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||||
|
ibv::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
} else if (bk == "nccl") {
|
||||||
group = nccl::init(strict);
|
group = nccl::init(strict);
|
||||||
|
} else if (bk == "ibv") {
|
||||||
|
group = ibv::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
if (mlx::core::cu::is_available()) {
|
if (mlx::core::cu::is_available()) {
|
||||||
group = nccl::init(false);
|
group = nccl::init(false);
|
||||||
@@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(false);
|
group = mpi::init(false);
|
||||||
bk_ = "mpi";
|
bk_ = "mpi";
|
||||||
}
|
}
|
||||||
|
if (group == nullptr) {
|
||||||
|
group = ibv::init(false);
|
||||||
|
bk_ = "ibv";
|
||||||
|
}
|
||||||
if (group == nullptr && strict) {
|
if (group == nullptr && strict) {
|
||||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||||
<< "and 'ring' but '" << bk << "' was provided.";
|
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
8
mlx/distributed/ibv/CMakeLists.txt
Normal file
8
mlx/distributed/ibv/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
if(MLX_BUILD_CPU
|
||||||
|
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
|
||||||
|
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
|
||||||
|
target_link_libraries(mlx PRIVATE rdma)
|
||||||
|
else()
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp)
|
||||||
|
endif()
|
||||||
1122
mlx/distributed/ibv/ibv.cpp
Normal file
1122
mlx/distributed/ibv/ibv.cpp
Normal file
File diff suppressed because it is too large
Load Diff
12
mlx/distributed/ibv/ibv.h
Normal file
12
mlx/distributed/ibv/ibv.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::ibv {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::ibv
|
||||||
20
mlx/distributed/ibv/no_ibv.cpp
Normal file
20
mlx/distributed/ibv/no_ibv.cpp
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/ibv/ibv.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::ibv {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
if (strict) {
|
||||||
|
throw std::runtime_error("Cannot initialize ibv distributed backend.");
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::ibv
|
||||||
38
mlx/distributed/reduction_ops.h
Normal file
38
mlx/distributed/reduction_ops.h
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SumOp {
|
||||||
|
void operator()(const T* input, T* output, size_t N) const {
|
||||||
|
while (N-- > 0) {
|
||||||
|
*output += *input;
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct MaxOp {
|
||||||
|
void operator()(const T* input, T* output, size_t N) const {
|
||||||
|
while (N-- > 0) {
|
||||||
|
*output = std::max(*output, *input);
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct MinOp {
|
||||||
|
void operator()(const T* input, T* output, size_t N) const {
|
||||||
|
while (N-- > 0) {
|
||||||
|
*output = std::min(*output, *input);
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <arpa/inet.h>
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@@ -22,6 +19,8 @@
|
|||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/reduction_ops.h"
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
#include "mlx/threadpool.h"
|
#include "mlx/threadpool.h"
|
||||||
|
|
||||||
#ifndef SOL_TCP
|
#ifndef SOL_TCP
|
||||||
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
|||||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||||
constexpr const int CONN_ATTEMPTS = 5;
|
constexpr const int CONN_ATTEMPTS = 5;
|
||||||
constexpr const int CONN_WAIT = 1000;
|
constexpr const int CONN_WAIT = 1000;
|
||||||
|
constexpr const char* RING_TAG = "[ring]";
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@@ -296,55 +296,6 @@ class CommunicationThreads {
|
|||||||
std::unordered_map<int, SocketThread> threads_;
|
std::unordered_map<int, SocketThread> threads_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct address_t {
|
|
||||||
sockaddr_storage addr;
|
|
||||||
socklen_t len;
|
|
||||||
|
|
||||||
const sockaddr* get() const {
|
|
||||||
return (struct sockaddr*)&addr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr from an ip and port provided as strings.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
|
||||||
struct addrinfo hints, *res;
|
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_UNSPEC;
|
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
|
||||||
|
|
||||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
|
||||||
if (status != 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip << ":" << port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
address_t result;
|
|
||||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
result.len = res->ai_addrlen;
|
|
||||||
freeaddrinfo(res);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip_port) {
|
|
||||||
auto colon = ip_port.find(":");
|
|
||||||
if (colon == std::string::npos) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip_port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
|
||||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
|
||||||
|
|
||||||
return parse_address(ip, port);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||||
* addresses in order of rank. For each rank there can be many addresses so
|
* addresses in order of rank. For each rank there can be many addresses so
|
||||||
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
|
|||||||
* ["ip3:5000", "ip3:5001"],
|
* ["ip3:5000", "ip3:5001"],
|
||||||
* ]
|
* ]
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||||
std::vector<std::vector<address_t>> nodes;
|
std::vector<std::vector<detail::address_t>> nodes;
|
||||||
std::ifstream f(hostfile);
|
std::ifstream f(hostfile);
|
||||||
|
|
||||||
json hosts = json::parse(f);
|
json hosts = json::parse(f);
|
||||||
for (auto& h : hosts) {
|
for (auto& h : hosts) {
|
||||||
std::vector<address_t> host;
|
std::vector<detail::address_t> host;
|
||||||
for (auto& ips : h) {
|
for (auto& ips : h) {
|
||||||
host.push_back(parse_address(ips.get<std::string>()));
|
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||||
}
|
}
|
||||||
nodes.push_back(std::move(host));
|
nodes.push_back(std::move(host));
|
||||||
}
|
}
|
||||||
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
|||||||
* Create a socket and accept one connection for each of the provided
|
* Create a socket and accept one connection for each of the provided
|
||||||
* addresses.
|
* addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
std::vector<int> accept_connections(
|
||||||
|
const std::vector<detail::address_t>& addresses) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
// Create the socket to wait for connections from the peers
|
detail::TCPSocket socket(RING_TAG);
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
socket.listen(RING_TAG, address);
|
||||||
if (sock < 0) {
|
sockets.push_back(socket.accept(RING_TAG).detach());
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we can launch immediately after shutdown by setting the
|
|
||||||
// reuseaddr option so that we don't get address already in use errors
|
|
||||||
int enable = 1;
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the address and port
|
|
||||||
success = bind(sock, address.get(), address.len);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections
|
|
||||||
success = listen(sock, 0);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
int peer_socket = accept(sock, nullptr, nullptr);
|
|
||||||
if (peer_socket < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the listening socket
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
|
|
||||||
sockets.push_back(peer_socket);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
|||||||
* provided addresses.
|
* provided addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> make_connections(
|
std::vector<int> make_connections(
|
||||||
const std::vector<address_t>& addresses,
|
const std::vector<detail::address_t>& addresses,
|
||||||
bool verbose) {
|
bool verbose) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
int sock;
|
sockets.push_back(detail::TCPSocket::connect(
|
||||||
|
RING_TAG,
|
||||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
address,
|
||||||
// backoff. TODO: Do we need that?
|
CONN_ATTEMPTS,
|
||||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
CONN_WAIT,
|
||||||
// Create the socket
|
[verbose](int attempt, int wait) {
|
||||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
||||||
if (sock < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (attempt > 0) {
|
|
||||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
|
||||||
log_info(
|
log_info(
|
||||||
verbose,
|
verbose,
|
||||||
"Attempt",
|
"Attempt",
|
||||||
attempt,
|
attempt,
|
||||||
"wait",
|
"waiting",
|
||||||
wait,
|
wait,
|
||||||
"ms (error:",
|
"ms (error:",
|
||||||
errno,
|
errno,
|
||||||
")");
|
")");
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
})
|
||||||
}
|
.detach());
|
||||||
|
|
||||||
success = connect(sock, address.get(), address.len);
|
|
||||||
if (success == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (success < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
sockets.push_back(sock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
struct SumOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output += *input;
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MaxOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::max(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MinOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::min(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class RingGroup : public GroupImpl {
|
class RingGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
RingGroup(
|
||||||
|
int rank,
|
||||||
|
std::vector<std::vector<detail::address_t>> nodes,
|
||||||
|
bool verbose)
|
||||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
void all_sum(const array& input, array& output, Stream stream) override {
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_max(const array& input, array& output, Stream stream) override {
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_min(const array& input, array& output, Stream stream) override {
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
|||||||
203
mlx/distributed/utils.cpp
Normal file
203
mlx/distributed/utils.cpp
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||||
|
struct addrinfo hints, *res;
|
||||||
|
memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_family = AF_UNSPEC;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
|
||||||
|
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||||
|
if (status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip << ":" << port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
address_t result;
|
||||||
|
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
result.len = res->ai_addrlen;
|
||||||
|
freeaddrinfo(res);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port) {
|
||||||
|
auto colon = ip_port.find(":");
|
||||||
|
if (colon == std::string::npos) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip_port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||||
|
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||||
|
|
||||||
|
return parse_address(ip, port);
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(const char* tag) {
|
||||||
|
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock_ < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||||
|
sock_ = s.sock_;
|
||||||
|
s.sock_ = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
|
||||||
|
if (this != &s) {
|
||||||
|
sock_ = s.sock_;
|
||||||
|
s.sock_ = -1;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||||
|
|
||||||
|
TCPSocket::~TCPSocket() {
|
||||||
|
if (sock_ > 0) {
|
||||||
|
shutdown(sock_, 2);
|
||||||
|
close(sock_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int TCPSocket::detach() {
|
||||||
|
int s = sock_;
|
||||||
|
sock_ = -1;
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||||
|
int success;
|
||||||
|
|
||||||
|
// Make sure we can launch immediately after shutdown by setting the
|
||||||
|
// reuseaddr option so that we don't get address already in use errors
|
||||||
|
int enable = 1;
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the socket to the address and port
|
||||||
|
success = bind(sock_, addr.get(), addr.len);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare waiting for connections
|
||||||
|
success = ::listen(sock_, 0);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::accept(const char* tag) {
|
||||||
|
int peer = ::accept(sock_, nullptr, nullptr);
|
||||||
|
if (peer < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Accept failed (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(peer);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::send(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Send failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<const char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::recv(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Recv failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries,
|
||||||
|
int wait,
|
||||||
|
std::function<void(int, int)> cb) {
|
||||||
|
int sock, success;
|
||||||
|
|
||||||
|
// Attempt to connect `num_retries` times with exponential backoff.
|
||||||
|
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||||
|
// Create the socket
|
||||||
|
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||||
|
<< ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
success = ::connect(sock, addr.get(), addr.len);
|
||||||
|
if (success == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(attempt, wait);
|
||||||
|
if (wait > 0) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||||
|
}
|
||||||
|
|
||||||
|
wait <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(sock);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
65
mlx/distributed/utils.h
Normal file
65
mlx/distributed/utils.h
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/socket.h>
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
struct address_t {
|
||||||
|
sockaddr_storage addr;
|
||||||
|
socklen_t len;
|
||||||
|
|
||||||
|
const sockaddr* get() const {
|
||||||
|
return (struct sockaddr*)&addr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Small wrapper over a TCP socket to simplify initiating connections.
|
||||||
|
*/
|
||||||
|
class TCPSocket {
|
||||||
|
public:
|
||||||
|
TCPSocket(const char* tag);
|
||||||
|
TCPSocket(const TCPSocket&) = delete;
|
||||||
|
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||||
|
TCPSocket(TCPSocket&& s);
|
||||||
|
TCPSocket& operator=(TCPSocket&&);
|
||||||
|
~TCPSocket();
|
||||||
|
|
||||||
|
void listen(const char* tag, const address_t& addr);
|
||||||
|
TCPSocket accept(const char* tag);
|
||||||
|
|
||||||
|
void send(const char* tag, const void* data, size_t len);
|
||||||
|
void recv(const char* tag, void* data, size_t len);
|
||||||
|
|
||||||
|
int detach();
|
||||||
|
|
||||||
|
operator int() const {
|
||||||
|
return sock_;
|
||||||
|
}
|
||||||
|
|
||||||
|
static TCPSocket connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries = 1,
|
||||||
|
int wait = 0,
|
||||||
|
std::function<void(int, int)> cb = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
TCPSocket(int sock);
|
||||||
|
|
||||||
|
int sock_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
Reference in New Issue
Block a user