Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: main ... ebda161a86 (17 commits)

Commits (SHA1):
ebda161a86
fa31a4b295
9d707ba3b5
405d30b6e5
cd4b12ce1b
425043ccca
95d92af8a0
bfdddd644b
1216afdc91
04e94d78bb
60d4e8b2a8
c5745fddd2
e937a8033f
4dfe02d7c6
5c2cff9329
325dab9559
67e454ab0a
CMakeLists.txt
@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
     OUTPUT_VARIABLE MACOS_SDK_VERSION
     OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+  execute_process(
+    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
+    OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
mlx/distributed/CMakeLists.txt
@@ -4,6 +4,11 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)

+if(MLX_BUILD_CPU AND NOT WIN32)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
+endif()
+
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
mlx/distributed/distributed.cpp
@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/jaccl/jaccl.h"
 #include "mlx/distributed/mpi/mpi.h"
 #include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/ring/ring.h"
@@ -102,7 +103,27 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail

 bool is_available() {
-  return mpi::is_available() || ring::is_available() || nccl::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available() ||
+      jaccl::is_available();
 }

+bool is_available(const std::string& bk) {
+  if (bk == "any") {
+    return is_available();
+  }
+  if (bk == "mpi") {
+    return mpi::is_available();
+  }
+  if (bk == "ring") {
+    return ring::is_available();
+  }
+  if (bk == "nccl") {
+    return nccl::is_available();
+  }
+  if (bk == "jaccl") {
+    return jaccl::is_available();
+  }
+  return false;
+}
+
 int Group::rank() const {
@@ -135,6 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
     group = ring::init(strict);
   } else if (bk == "nccl") {
     group = nccl::init(strict);
+  } else if (bk == "jaccl") {
+    group = jaccl::init(strict);
   } else if (bk == "any") {
     if (mlx::core::cu::is_available()) {
       group = nccl::init(false);
@@ -148,13 +171,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
       group = mpi::init(false);
       bk_ = "mpi";
     }
+    if (group == nullptr) {
+      group = jaccl::init(false);
+      bk_ = "jaccl";
+    }
     if (group == nullptr && strict) {
       throw std::runtime_error("[distributed] Couldn't initialize any backend");
     }
   } else {
     std::ostringstream msg;
-    msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
-        << "and 'ring' but '" << bk << "' was provided.";
+    msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
+        << "'jaccl' and 'ring' but '" << bk << "' was provided.";
     throw std::invalid_argument(msg.str());
   }

mlx/distributed/distributed.h
@@ -16,6 +16,7 @@ class GroupImpl;

 /* Check if a communication backend is available */
 bool is_available();
+bool is_available(const std::string& bk);

 /**
  * A distributed::Group represents a group of independent mlx processes that
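The hunks above add a per-backend availability query and route bk == "jaccl" through the new backend, with a jaccl fallback in the "any" path. The following is a minimal usage sketch, not part of this change set; it assumes the existing public C++ API (mlx/mlx.h, distributed::all_sum from mlx/distributed/ops.h, and Group::rank()/Group::size()):

// Hypothetical caller of the selection logic added above.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;
namespace dist = mlx::core::distributed;

int main() {
  // New overload from distributed.h: probe one backend by name.
  if (dist::is_available("jaccl")) {
    std::cout << "jaccl backend compiled in and usable\n";
  }

  // With strict = false a failed init falls back to the single-process
  // group; the "any" path above now also tries jaccl after mpi.
  auto group = dist::init(/* strict */ false, /* bk */ "any");
  std::cout << "rank " << group.rank() << " of " << group.size() << "\n";

  // All-reduce a small array across the group.
  auto x = mx::ones({4});
  auto y = dist::all_sum(x, group);
  mx::eval(y);
  return 0;
}

Passing "jaccl" with strict = true instead makes init throw if the backend cannot be brought up, matching the no_jaccl.cpp stub below.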
mlx/distributed/jaccl/CMakeLists.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
if(MLX_BUILD_CPU
   AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
   AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
  target_link_libraries(mlx PRIVATE rdma)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()

mlx/distributed/jaccl/jaccl.cpp (new file, 1123 lines; diff suppressed because it is too large)
mlx/distributed/jaccl/jaccl.h (new file, 12 lines)
@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::jaccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::jaccl
mlx/distributed/jaccl/no_jaccl.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.

#include "mlx/distributed/jaccl/jaccl.h"

namespace mlx::core::distributed::jaccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize jaccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::jaccl
mlx/distributed/reduction_ops.h (new file, 38 lines)
@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.

namespace mlx::core::distributed::detail {

template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output += *input;
      input++;
      output++;
    }
  }
};

template <typename T>
struct MaxOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output = std::max(*output, *input);
      input++;
      output++;
    }
  }
};

template <typename T>
struct MinOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output = std::min(*output, *input);
      input++;
      output++;
    }
  }
};

} // namespace mlx::core::distributed::detail
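These functors are what the refactored ring backend now passes into its all_reduce template (see the ring.cpp hunks below). Here is a standalone sketch of their buffer-level behaviour, with made-up data; <algorithm> and <cstddef> are included explicitly because the header appears to rely on its includer for std::max/std::min and size_t:

#include <algorithm>
#include <cstddef>
#include <iostream>

#include "mlx/distributed/reduction_ops.h"

int main() {
  using mlx::core::distributed::detail::MaxOp;
  using mlx::core::distributed::detail::SumOp;

  float recv[4] = {1.f, 2.f, 3.f, 4.f}; // chunk received from a peer
  float acc[4] = {10.f, 20.f, 30.f, 40.f}; // local accumulator

  SumOp<float>()(recv, acc, 4); // acc becomes {11, 22, 33, 44}
  MaxOp<float>()(recv, acc, 4); // acc unchanged: every element already larger

  for (float v : acc) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
  return 0;
}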
mlx/distributed/ring/ring.cpp
@@ -1,9 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
@@ -22,6 +19,8 @@
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/reduction_ops.h"
|
||||
#include "mlx/distributed/utils.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#ifndef SOL_TCP
|
||||
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
constexpr const char* RING_TAG = "[ring]";
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
using json = nlohmann::json;
|
||||
@@ -296,55 +296,6 @@ class CommunicationThreads {
|
||||
std::unordered_map<int, SocketThread> threads_;
|
||||
};
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
|
||||
const sockaddr* get() const {
|
||||
return (struct sockaddr*)&addr;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||
* addresses in order of rank. For each rank there can be many addresses so
|
||||
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
|
||||
* ["ip3:5000", "ip3:5001"],
|
||||
* ]
|
||||
*/
|
||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<address_t>> nodes;
|
||||
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<detail::address_t>> nodes;
|
||||
std::ifstream f(hostfile);
|
||||
|
||||
json hosts = json::parse(f);
|
||||
for (auto& h : hosts) {
|
||||
std::vector<address_t> host;
|
||||
std::vector<detail::address_t> host;
|
||||
for (auto& ips : h) {
|
||||
host.push_back(parse_address(ips.get<std::string>()));
|
||||
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||
}
|
||||
nodes.push_back(std::move(host));
|
||||
}
|
||||
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
||||
* Create a socket and accept one connection for each of the provided
|
||||
* addresses.
|
||||
*/
|
||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
std::vector<int> accept_connections(
|
||||
const std::vector<detail::address_t>& addresses) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock, address.get(), address.len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
|
||||
sockets.push_back(peer_socket);
|
||||
detail::TCPSocket socket(RING_TAG);
|
||||
socket.listen(RING_TAG, address);
|
||||
sockets.push_back(socket.accept(RING_TAG).detach());
|
||||
}
|
||||
|
||||
return sockets;
|
||||
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
* provided addresses.
|
||||
*/
|
||||
std::vector<int> make_connections(
|
||||
const std::vector<address_t>& addresses,
|
||||
const std::vector<detail::address_t>& addresses,
|
||||
bool verbose) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
int sock;
|
||||
|
||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
||||
// backoff. TODO: Do we need that?
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"wait",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
success = connect(sock, address.get(), address.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockets.push_back(sock);
|
||||
sockets.push_back(detail::TCPSocket::connect(
|
||||
RING_TAG,
|
||||
address,
|
||||
CONN_ATTEMPTS,
|
||||
CONN_WAIT,
|
||||
[verbose](int attempt, int wait) {
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"waiting",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
})
|
||||
.detach());
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::max(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MinOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::min(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class RingGroup : public GroupImpl {
|
||||
public:
|
||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
||||
RingGroup(
|
||||
int rank,
|
||||
std::vector<std::vector<detail::address_t>> nodes,
|
||||
bool verbose)
|
||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||
throw std::runtime_error(
|
||||
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
||||
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
||||
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
||||
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
|
||||
mlx/distributed/utils.cpp (new file, 204 lines)
@@ -0,0 +1,204 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "mlx/distributed/utils.h"
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
std::memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(const char* tag) {
|
||||
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock_ < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||
sock_ = s.sock_;
|
||||
s.sock_ = -1;
|
||||
}
|
||||
|
||||
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
|
||||
if (this != &s) {
|
||||
sock_ = s.sock_;
|
||||
s.sock_ = -1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||
|
||||
TCPSocket::~TCPSocket() {
|
||||
if (sock_ > 0) {
|
||||
shutdown(sock_, 2);
|
||||
close(sock_);
|
||||
}
|
||||
}
|
||||
|
||||
int TCPSocket::detach() {
|
||||
int s = sock_;
|
||||
sock_ = -1;
|
||||
return s;
|
||||
}
|
||||
|
||||
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||
int success;
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock_, addr.get(), addr.len);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Prepare waiting for connections
|
||||
success = ::listen(sock_, 0);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::accept(const char* tag) {
|
||||
int peer = ::accept(sock_, nullptr, nullptr);
|
||||
if (peer < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(peer);
|
||||
}
|
||||
|
||||
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::send(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Send failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<const char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::recv(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Recv failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::connect(
|
||||
const char* tag,
|
||||
const address_t& addr,
|
||||
int num_retries,
|
||||
int wait,
|
||||
std::function<void(int, int)> cb) {
|
||||
int sock, success;
|
||||
|
||||
// Attempt to connect `num_retries` times with exponential backoff.
|
||||
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
success = ::connect(sock, addr.get(), addr.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
cb(attempt, wait);
|
||||
if (wait > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
wait <<= 1;
|
||||
}
|
||||
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(sock);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
mlx/distributed/utils.h (new file, 67 lines)
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <sys/socket.h>
#include <functional>
#include <string>

namespace mlx::core::distributed::detail {

struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  const sockaddr* get() const {
    return (struct sockaddr*)&addr;
  }
};

/**
 * Parse a sockaddr from an ip and port provided as strings.
 */
address_t parse_address(const std::string& ip, const std::string& port);

/**
 * Parse a sockaddr provided as an <ip>:<port> string.
 */
address_t parse_address(const std::string& ip_port);

/**
 * Small wrapper over a TCP socket to simplify initiating connections.
 */
class TCPSocket {
 public:
  TCPSocket(const char* tag);
  TCPSocket(const TCPSocket&) = delete;
  TCPSocket& operator=(const TCPSocket&) = delete;
  TCPSocket(TCPSocket&& s);
  TCPSocket& operator=(TCPSocket&&);
  ~TCPSocket();

  void listen(const char* tag, const address_t& addr);
  TCPSocket accept(const char* tag);

  void send(const char* tag, const void* data, size_t len);
  void recv(const char* tag, void* data, size_t len);

  int detach();

  operator int() const {
    return sock_;
  }

  static TCPSocket connect(
      const char* tag,
      const address_t& addr,
      int num_retries = 1,
      int wait = 0,
      std::function<void(int, int)> cb = nullptr);

 private:
  TCPSocket(int sock);

  int sock_;
};

} // namespace mlx::core::distributed::detail
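A hypothetical two-process sketch of the wrapper declared above follows; the tag strings, the port, and the --server flag are invented for illustration. ring.cpp (above) uses the same listen/accept and connect-with-retries pattern:

#include <iostream>
#include <string>

#include "mlx/distributed/utils.h"

using namespace mlx::core::distributed::detail;

int main(int argc, char** argv) {
  bool server = argc > 1 && std::string(argv[1]) == "--server";
  auto addr = parse_address("127.0.0.1:5000");

  if (server) {
    // Bind, wait for one peer, read a short message.
    TCPSocket listener("[demo]");
    listener.listen("[demo]", addr);
    TCPSocket peer = listener.accept("[demo]");
    char buf[6] = {0};
    peer.recv("[demo]", buf, 5);
    std::cout << "got: " << buf << "\n";
  } else {
    // Retry a few times with exponential backoff, logging each attempt.
    auto sock = TCPSocket::connect(
        "[demo]", addr, /* num_retries */ 5, /* wait (ms) */ 100,
        [](int attempt, int wait) {
          std::cerr << "retry " << attempt << " in " << wait << " ms\n";
        });
    sock.send("[demo]", "hello", 5);
  }
  return 0;
}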
python/mlx/_distributed_utils/common.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import ipaddress
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Host:
|
||||
rank: int
|
||||
ssh_hostname: str
|
||||
ips: list[str]
|
||||
rdma: list[Optional[str]]
|
||||
|
||||
|
||||
class OptionalBoolAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if option_string.startswith("--no-"):
|
||||
setattr(namespace, self.dest, False)
|
||||
else:
|
||||
setattr(namespace, self.dest, True)
|
||||
|
||||
|
||||
def positive_number(x):
|
||||
x = int(x)
|
||||
if x <= 0:
|
||||
raise ValueError("Number should be positive")
|
||||
return x
|
||||
|
||||
|
||||
def log(verbose, *args, **kwargs):
|
||||
if not verbose:
|
||||
return
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_warning(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_error(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def parse_hostlist(parser, hostlist, repeats):
|
||||
hosts = []
|
||||
for i, h in enumerate(hostlist.split(",")):
|
||||
if h == "":
|
||||
raise ValueError("Hostname cannot be empty")
|
||||
try:
|
||||
ipaddress.ip_address(h)
|
||||
ips = [h]
|
||||
except ValueError:
|
||||
ips = []
|
||||
for i in range(repeats):
|
||||
hosts.append(Host(i, h, ips, []))
|
||||
return hosts
|
||||
|
||||
|
||||
def parse_hostfile(parser, hostfile):
|
||||
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||
the ips to communicate over when using the ring backend.
|
||||
|
||||
Example:
|
||||
|
||||
[
|
||||
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
|
||||
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
|
||||
...
|
||||
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
|
||||
]
|
||||
|
||||
Args:
|
||||
hostfile (str): The path to the json file containing the host
|
||||
information
|
||||
"""
|
||||
hostfile = Path(hostfile)
|
||||
if not hostfile.exists():
|
||||
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
|
||||
|
||||
try:
|
||||
hosts = []
|
||||
with open(hostfile) as f:
|
||||
for i, h in enumerate(json.load(f)):
|
||||
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
|
||||
return hosts
|
||||
except Exception as e:
|
||||
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||
python/mlx/_distributed_utils/config.py (new file, 568 lines)
@@ -0,0 +1,568 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shlex
|
||||
import sys
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from subprocess import DEVNULL, run
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .common import (
    Host,
    OptionalBoolAction,
    log,
    log_error,
    log_warning,
    parse_hostfile,
    parse_hostlist,
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHInfo:
|
||||
can_ssh: bool
|
||||
has_sudo: bool
|
||||
|
||||
def __bool__(self):
|
||||
return self.can_ssh
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThunderboltPort:
|
||||
iface: str
|
||||
uuid: str
|
||||
connected_to: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThunderboltHost:
|
||||
name: str
|
||||
ports: list[ThunderboltPort]
|
||||
|
||||
|
||||
def add_ethernet_ips(hosts, verbose=False):
|
||||
# Get the ips for each host
|
||||
for h in hosts:
|
||||
log(verbose, "Getting the ip from", h.ssh_hostname)
|
||||
h.ips.append(
|
||||
run(
|
||||
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
).stdout.strip()
|
||||
)
|
||||
|
||||
|
||||
def check_rdma(hosts, verbose=False):
|
||||
# Check whether the hosts are capable of RDMA over thunderbolt
|
||||
warn = False
|
||||
for h in hosts:
|
||||
log(verbose, "Checking that", h.ssh_hostname, "supports RDMA")
|
||||
rdma_devs = (
|
||||
run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True)
|
||||
.stdout.strip()
|
||||
.split()
|
||||
)
|
||||
rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")]
|
||||
if not rdma_devs:
|
||||
log_warning(h.ssh_hostname, "does not seem to have RDMA enabled")
|
||||
warn = True
|
||||
|
||||
if warn:
|
||||
log_warning()
|
||||
log_warning(
|
||||
"Some of the hosts don't have RDMA enabled or they don't support RDMA."
|
||||
)
|
||||
log_warning()
|
||||
log_warning(
|
||||
"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
|
||||
)
|
||||
log_warning("for instructions on how to enable RDMA.")
|
||||
|
||||
|
||||
def can_auto_setup(hosts, sshinfo, auto_setup=False):
|
||||
has_sudo = all(info.has_sudo for info in sshinfo)
|
||||
if not has_sudo and auto_setup:
|
||||
log_warning(
|
||||
"Automatic setup requested but the following hosts do not have passwordless sudo"
|
||||
)
|
||||
for h, i in zip(hosts, sshinfo):
|
||||
if not i.has_sudo:
|
||||
log_warning(" - ", h.ssh_hostname)
|
||||
return has_sudo
|
||||
|
||||
|
||||
class IPConfigurator:
|
||||
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
|
||||
assigned = set()
|
||||
ips = defaultdict(list)
|
||||
ip0 = 0
|
||||
ip1 = 0
|
||||
for src_node, h in enumerate(tb_hosts):
|
||||
for src_port, p in enumerate(h.ports):
|
||||
if not p.connected_to:
|
||||
continue
|
||||
if (src_node, src_port) in assigned:
|
||||
continue
|
||||
|
||||
dst_node, dst_port = uuid_reverse_index[p.connected_to]
|
||||
|
||||
ip_src = f"192.168.{ip0}.{ip1 + 1}"
|
||||
ip_dst = f"192.168.{ip0}.{ip1 + 2}"
|
||||
iface_src = p.iface
|
||||
iface_dst = tb_hosts[dst_node].ports[dst_port].iface
|
||||
|
||||
ips[src_node, dst_node].append((iface_src, ip_src))
|
||||
ips[dst_node, src_node].append((iface_dst, ip_dst))
|
||||
|
||||
assigned.add((src_node, src_port))
|
||||
assigned.add((dst_node, dst_port))
|
||||
|
||||
ip1 += 4
|
||||
if ip1 > 255:
|
||||
ip0 += 1
|
||||
ip1 = 0
|
||||
if ip0 > 255:
|
||||
raise ValueError("Ran out of available local IPs")
|
||||
|
||||
self.ips = ips
|
||||
self.hosts = hosts
|
||||
self.tb_hosts = tb_hosts
|
||||
|
||||
def setup(self, verbose=False, auto_setup=False):
|
||||
netmask = "255.255.255.252"
|
||||
for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):
|
||||
command = ""
|
||||
command += "sudo ifconfig bridge0 down\n"
|
||||
for j in range(len(self.hosts)):
|
||||
if i == j or (i, j) not in self.ips:
|
||||
continue
|
||||
for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
|
||||
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
|
||||
command += f"sudo route change {peer} -interface {iface}\n"
|
||||
if auto_setup:
|
||||
print(f"Running auto setup for {h.ssh_hostname}")
|
||||
command = command.strip().replace("\n", " ; ")
|
||||
command = ["ssh", h.ssh_hostname, command]
|
||||
log(verbose, shlex.join(command))
|
||||
run(command)
|
||||
else:
|
||||
msg = f"Setup for {h.ssh_hostname}"
|
||||
print(msg)
|
||||
print("=" * len(msg))
|
||||
print(command)
|
||||
input("Enter to continue")
|
||||
print()
|
||||
|
||||
|
||||
def parse_hardware_ports(ports_string):
|
||||
ports = {}
|
||||
port_name = None
|
||||
for l in ports_string.decode("utf-8").split("\n"):
|
||||
if l.startswith("Hardware Port:"):
|
||||
port_name = l.strip()[15:]
|
||||
elif l.startswith("Device:"):
|
||||
ports[port_name] = l.strip()[8:]
|
||||
port_name = None
|
||||
return ports
|
||||
|
||||
|
||||
def extract_connectivity(hosts, verbose):
|
||||
# Extract the current connectivity from the remote hosts
|
||||
thunderbolt_connections = []
|
||||
for h in hosts:
|
||||
log(verbose, "Getting connectivity from", h.ssh_hostname)
|
||||
thunderbolt_connections.append(
|
||||
json.loads(
|
||||
run(
|
||||
[
|
||||
"ssh",
|
||||
h.ssh_hostname,
|
||||
"system_profiler",
|
||||
"SPThunderboltDataType",
|
||||
"-json",
|
||||
],
|
||||
capture_output=True,
|
||||
).stdout
|
||||
)
|
||||
)
|
||||
interface_maps = []
|
||||
for h in hosts:
|
||||
log(verbose, "Getting interface names from", h.ssh_hostname)
|
||||
interface_maps.append(
|
||||
parse_hardware_ports(
|
||||
run(
|
||||
[
|
||||
"ssh",
|
||||
h.ssh_hostname,
|
||||
"networksetup",
|
||||
"-listallhardwareports",
|
||||
],
|
||||
capture_output=True,
|
||||
).stdout
|
||||
)
|
||||
)
|
||||
|
||||
# Parse the connectivity into some simple dataclasses
|
||||
tb_hosts = []
|
||||
for c, iface_map in zip(thunderbolt_connections, interface_maps):
|
||||
name = ""
|
||||
ports = []
|
||||
for t in c["SPThunderboltDataType"]:
|
||||
uuid = t.get("domain_uuid_key")
|
||||
if uuid is None:
|
||||
continue
|
||||
name = t["device_name_key"]
|
||||
tag = t["receptacle_1_tag"]["receptacle_id_key"]
|
||||
items = t.get("_items", [])
|
||||
connected_items = [item for item in items if "domain_uuid_key" in item]
|
||||
connected_to = (
|
||||
connected_items[0]["domain_uuid_key"] if connected_items else None
|
||||
)
|
||||
iface = iface_map[f"Thunderbolt {tag}"]
|
||||
ports.append(ThunderboltPort(iface, uuid, connected_to))
|
||||
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
|
||||
|
||||
# Create a reverse index to be able to map uuids to (host, port) quickly
|
||||
uuid_reverse_index = {}
|
||||
for i, h in enumerate(tb_hosts):
|
||||
for j, p in enumerate(h.ports):
|
||||
uuid_reverse_index[p.uuid] = (i, j)
|
||||
|
||||
return tb_hosts, uuid_reverse_index
|
||||
|
||||
|
||||
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
|
||||
connectivity = []
|
||||
for i, h in enumerate(tb_hosts):
|
||||
c = [0] * len(tb_hosts)
|
||||
for p in h.ports:
|
||||
if p.connected_to is not None:
|
||||
j, _ = uuid_reverse_index[p.connected_to]
|
||||
c[j] += 1
|
||||
connectivity.append(c)
|
||||
return connectivity
|
||||
|
||||
|
||||
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
|
||||
# Make ids per node
|
||||
names = []
|
||||
for i in range(len(tb_hosts)):
|
||||
n = ""
|
||||
j = i
|
||||
while True:
|
||||
n += chr(97 + j % 26)
|
||||
j //= 26
|
||||
if j == 0:
|
||||
break
|
||||
names.append(n)
|
||||
|
||||
print("graph G {")
|
||||
print(" node [shape=rectangle];")
|
||||
for i, h in enumerate(hosts):
|
||||
print(f' {names[i]} [label="{h.ssh_hostname}"];')
|
||||
for i, h in enumerate(tb_hosts):
|
||||
for p in h.ports:
|
||||
if not p.connected_to:
|
||||
continue
|
||||
dst = uuid_reverse_index[p.connected_to]
|
||||
if dst[0] < i:
|
||||
continue
|
||||
print(f" {names[i]} -- {names[dst[0]]}", end="")
|
||||
print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
|
||||
print("}")
|
||||
|
||||
|
||||
def extract_rings(connectivity):
|
||||
rings = []
|
||||
existing_rings = set()
|
||||
num_nodes = len(connectivity)
|
||||
|
||||
def dfs(start_node, node, path, visited):
|
||||
path.append(node)
|
||||
visited.add(node)
|
||||
for j in range(num_nodes):
|
||||
if connectivity[node][j] <= 0:
|
||||
continue
|
||||
if j == start_node:
|
||||
yield path[:]
|
||||
if j not in visited:
|
||||
yield from dfs(start_node, j, path, visited)
|
||||
path.pop()
|
||||
visited.remove(node)
|
||||
|
||||
for start in range(num_nodes):
|
||||
for r in dfs(start, start, [], set()):
|
||||
cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))
|
||||
rkey = tuple(sorted(r))
|
||||
if rkey not in existing_rings:
|
||||
rings.append((r, cnt))
|
||||
existing_rings.add(rkey)
|
||||
|
||||
return sorted(rings, key=lambda x: -len(x[0]))
|
||||
|
||||
|
||||
def check_valid_mesh(hosts, connectivity, strict=True):
|
||||
num_nodes = len(connectivity)
|
||||
for i in range(num_nodes):
|
||||
for j in range(num_nodes):
|
||||
if i == j:
|
||||
continue
|
||||
if connectivity[i][j] <= 0:
|
||||
if strict:
|
||||
log_error(
|
||||
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
|
||||
)
|
||||
log_error()
|
||||
log_error("Try passing --dot to visualize the connectivity")
|
||||
sys.exit(1)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_ssh_connections(hosts):
|
||||
results = [None] * len(hosts)
|
||||
|
||||
def _check(hostname, i):
|
||||
info = SSHInfo(False, False)
|
||||
results[i] = info
|
||||
|
||||
# Check for ssh
|
||||
result = run(
|
||||
[
|
||||
"ssh",
|
||||
"-o",
|
||||
"BatchMode=yes",
|
||||
"-o",
|
||||
"ConnectTimeout=5",
|
||||
hostname,
|
||||
"echo",
|
||||
"success",
|
||||
],
|
||||
stdout=DEVNULL,
|
||||
stderr=DEVNULL,
|
||||
)
|
||||
info.can_ssh = result.returncode == 0
|
||||
if not info.can_ssh:
|
||||
return
|
||||
|
||||
# Check for sudo
|
||||
result = run(
|
||||
[
|
||||
"ssh",
|
||||
"-o",
|
||||
"BatchMode=yes",
|
||||
"-o",
|
||||
"ConnectTimeout=5",
|
||||
hostname,
|
||||
"sudo",
|
||||
"ls",
|
||||
],
|
||||
stdout=DEVNULL,
|
||||
stderr=DEVNULL,
|
||||
)
|
||||
info.has_sudo = result.returncode == 0
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=_check, args=(h.ssh_hostname, i))
|
||||
for i, h in enumerate(hosts)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
if not all(results):
|
||||
log_error("Could not ssh to the following hosts:")
|
||||
for i, h in enumerate(hosts):
|
||||
if not results[i]:
|
||||
log_error(" - ", h.ssh_hostname)
|
||||
log_error()
|
||||
log_error("Maybe they are not set-up for password-less ssh?")
|
||||
sys.exit(1)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def prepare_ethernet_hostfile(args, hosts):
|
||||
log(args.verbose, f"Preparing an ethernet hostfile")
|
||||
add_ethernet_ips(hosts, args.verbose)
|
||||
|
||||
hostfile = []
|
||||
for h in hosts:
|
||||
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
json.dump(hostfile, f, indent=4)
|
||||
else:
|
||||
print("Hostfile")
|
||||
print("========")
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def configure_ring(args, hosts, ips, ring, sshinfo):
|
||||
log(args.verbose, "Prepare a ring hostfile")
|
||||
ring, count = ring
|
||||
hostfile = []
|
||||
for i, node in enumerate(ring):
|
||||
h = hosts[node]
|
||||
peer = ring[i - 1]
|
||||
hostfile.append(
|
||||
{
|
||||
"ssh": h.ssh_hostname,
|
||||
"ips": [ips.ips[node, peer][c][1] for c in range(count)],
|
||||
"rdma": [],
|
||||
}
|
||||
)
|
||||
|
||||
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
json.dump(hostfile, f, indent=4)
|
||||
else:
|
||||
print("Hostfile")
|
||||
print("========")
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def configure_jaccl(args, hosts, ips, sshinfo):
|
||||
log(args.verbose, "Prepare a jaccl hostfile")
|
||||
check_rdma(hosts, args.verbose)
|
||||
add_ethernet_ips(hosts, args.verbose)
|
||||
|
||||
hostfile = []
|
||||
for i, h in enumerate(hosts):
|
||||
rdma = []
|
||||
for j in range(len(hosts)):
|
||||
if i == j:
|
||||
rdma.append(None)
|
||||
else:
|
||||
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
|
||||
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
|
||||
|
||||
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
json.dump(hostfile, f, indent=4)
|
||||
else:
|
||||
print("Hostfile")
|
||||
print("========")
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def prepare_tb_hostfile(args, hosts, sshinfo):
|
||||
log(args.verbose, f"Preparing for communication over thunderbolt")
|
||||
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
|
||||
|
||||
if args.dot:
|
||||
tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
|
||||
return
|
||||
|
||||
ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
|
||||
connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)
|
||||
|
||||
if args.backend is None:
|
||||
rings = extract_rings(connectivity)
|
||||
has_mesh = check_valid_mesh(hosts, connectivity, False)
|
||||
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
|
||||
|
||||
if not has_ring and not has_mesh:
|
||||
log_error("Neither thunderbolt mesh nor ring found.")
|
||||
log_error("Perhaps run with --dot to generate a plot of the connectivity.")
|
||||
sys.exit(1)
|
||||
|
||||
elif has_ring:
|
||||
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||
|
||||
else:
|
||||
configure_jaccl(args, hosts, ips, sshinfo)
|
||||
|
||||
elif args.backend == "ring":
|
||||
rings = extract_rings(connectivity)
|
||||
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
|
||||
if not has_ring:
|
||||
log_error("Could not find a full ring.")
|
||||
log_error()
|
||||
log_error("Try passing --dot to visualize the connectivity")
|
||||
if len(rings) > 0:
|
||||
log_error("Rings found:")
|
||||
for r in rings:
|
||||
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r[0])}")
|
||||
sys.exit(1)
|
||||
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||
|
||||
elif args.backend == "jaccl":
|
||||
check_valid_mesh(hosts, connectivity)
|
||||
configure_jaccl(args, hosts, ips, sshinfo)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Configure remote machines for use with MLX distributed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||
)
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--over",
|
||||
choices=["thunderbolt", "ethernet"],
|
||||
default="thunderbolt",
|
||||
help="What type of connectivity to configure",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-hostfile", help="If provided, save the hostfile to this path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-setup",
|
||||
"--no-auto-setup",
|
||||
action=OptionalBoolAction,
|
||||
nargs=0,
|
||||
dest="auto_setup",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dot", action="store_true", help="Output the topology in DOT format and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "jaccl"],
|
||||
default=None,
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.hostfile is not None:
|
||||
hosts = parse_hostfile(parser, args.hostfile)
|
||||
else:
|
||||
hosts = parse_hostlist(parser, args.hosts, 1)
|
||||
|
||||
# Check that we can ssh
|
||||
log(
|
||||
args.verbose,
|
||||
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
|
||||
)
|
||||
sshinfo = check_ssh_connections(hosts)
|
||||
|
||||
# Prepare a hostfile for communication over ethernet using the ips of the
|
||||
# provided hostnames
|
||||
if args.over == "ethernet":
|
||||
prepare_ethernet_hostfile(args, hosts)
|
||||
|
||||
# Configure the macs for communication over thunderbolt, both via RDMA and IP
|
||||
else:
|
||||
prepare_tb_hostfile(args, hosts, sshinfo)
|
||||
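The extract_rings() helper above enumerates rings by depth-first search over the Thunderbolt connectivity matrix and de-duplicates them by node set. The following is an illustrative, standalone re-implementation of that idea in C++ (the 3-host matrix is made up for the example; the Python tool additionally tracks how many parallel cables each ring can use and sorts rings by length):

#include <algorithm>
#include <iostream>
#include <set>
#include <vector>

static void dfs(
    const std::vector<std::vector<int>>& conn,
    int start,
    int node,
    std::vector<int>& path,
    std::vector<bool>& visited,
    std::set<std::vector<int>>& seen,
    std::vector<std::vector<int>>& rings) {
  path.push_back(node);
  visited[node] = true;
  for (int j = 0; j < static_cast<int>(conn.size()); ++j) {
    if (conn[node][j] <= 0) {
      continue;
    }
    if (j == start) {
      // Found a cycle back to the starting host; keep it if this node set
      // has not been recorded yet (2-node "rings" are kept too, as in the
      // Python helper).
      auto key = path;
      std::sort(key.begin(), key.end());
      if (seen.insert(key).second) {
        rings.push_back(path);
      }
    } else if (!visited[j]) {
      dfs(conn, start, j, path, visited, seen, rings);
    }
  }
  path.pop_back();
  visited[node] = false;
}

int main() {
  // Three hosts, every pair joined by one Thunderbolt cable.
  std::vector<std::vector<int>> conn = {{0, 1, 1}, {1, 0, 1}, {1, 1, 0}};
  std::vector<std::vector<int>> rings;
  std::set<std::vector<int>> seen;
  for (int s = 0; s < static_cast<int>(conn.size()); ++s) {
    std::vector<int> path;
    std::vector<bool> visited(conn.size(), false);
    dfs(conn, s, s, path, visited, seen, rings);
  }
  for (const auto& r : rings) {
    for (int n : r) {
      std::cout << n << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}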
python/mlx/_distributed_utils/launch.py (new file, 546 lines)
@@ -0,0 +1,546 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
import platform
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
from collections import Counter
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from queue import Empty as QueueEmpty
|
||||
from queue import Queue
|
||||
from select import select
|
||||
from subprocess import PIPE, Popen, run
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
|
||||
|
||||
|
||||
class CommandProcess:
|
||||
@property
|
||||
def process(self):
|
||||
"""Return the Popen object that refers to the current command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def exit_status(self):
|
||||
"""Return a tuple (returncode, killed) for the command. It should be
|
||||
(None, None) while the command is running normally."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def preprocess_output(self, data: str, is_stdout=False):
|
||||
"""Preprocess the output of the command so that extra data can be
|
||||
captured or the format changed on the fly."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def terminate(self):
|
||||
"""Terminate or return the exit code."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RemoteProcess(CommandProcess):
|
||||
def __init__(self, rank, host, python, cwd, files, env, command):
|
||||
is_local = host == "127.0.0.1"
|
||||
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
|
||||
if not is_local:
|
||||
cmd = f"ssh {host} {shlex.quote(cmd)}"
|
||||
|
||||
self._host = host
|
||||
self._pidfile = None
|
||||
self._is_local = is_local
|
||||
self._process = Popen(
|
||||
cmd,
|
||||
shell=True,
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
|
||||
self._killed = False
|
||||
|
||||
@property
|
||||
def process(self):
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def exit_status(self):
|
||||
return self._process.poll(), self._killed
|
||||
|
||||
def preprocess_output(self, data, is_stdout=False):
|
||||
if self._pidfile is None:
|
||||
pidfile, *rest = data.split("\n", maxsplit=1)
|
||||
self._pidfile = pidfile
|
||||
return rest[0] if rest else ""
|
||||
|
||||
return data
|
||||
|
||||
def terminate(self):
|
||||
if self._killed:
|
||||
return
|
||||
|
||||
self._process.terminate()
|
||||
self._process.wait()
|
||||
|
||||
# Kill the remote program if possible
|
||||
cmd = RemoteProcess.make_kill_script(self._pidfile)
|
||||
if not self._is_local:
|
||||
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
|
||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||
|
||||
self._killed = c.stdout.strip() == "1"
|
||||
|
||||
@staticmethod
|
||||
def make_launch_script(rank, cwd, files, env, command):
|
||||
script = ""
|
||||
|
||||
# Write the PID to a file so we can kill the process if needed
|
||||
script += "pidfile=$(mktemp); "
|
||||
script += "echo $$ > $pidfile; "
|
||||
script += "echo $pidfile; "
|
||||
|
||||
# Change the working directory if one was requested. Otherwise attempt to
|
||||
# change to the current one but don't fail if it wasn't possible.
|
||||
d = cwd or os.getcwd()
|
||||
script += f"if [[ -d {repr(d)} ]]; then "
|
||||
script += f" cd {repr(d)}; "
|
||||
if cwd is not None:
|
||||
script += "else "
|
||||
script += f"  echo 'Failed to change directory to' {repr(d)} >&2; "
|
||||
script += "fi; "
|
||||
|
||||
# Add the environment variables that were requested
|
||||
for e in env:
|
||||
key, *value = e.split("=", maxsplit=1)
|
||||
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
log_warning(
|
||||
f"'{e}' is an invalid environment variable so it is ignored"
|
||||
)
|
||||
continue
|
||||
script += f"export {key}={value}; "
|
||||
|
||||
# Make the temporary files
|
||||
for env_name, content in files.items():
|
||||
script += "fname=$(mktemp); "
|
||||
script += f"echo {shlex.quote(content)} >$fname; "
|
||||
script += f"export {env_name}=$fname; "
|
||||
|
||||
# Finally add the rank
|
||||
script += f"export MLX_RANK={rank}; "
|
||||
|
||||
# Replace the process with the script
|
||||
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
|
||||
script += 'exec "${cmd[@]}"'
|
||||
|
||||
return script
|
||||
|
||||
@staticmethod
|
||||
def make_kill_script(pidfile):
|
||||
script = ""
|
||||
script += f"pid=$(cat {pidfile}); "
|
||||
script += "if ps -p $pid >/dev/null; then "
|
||||
script += " kill $pid; "
|
||||
script += " echo 1; "
|
||||
script += "else "
|
||||
script += " echo 0; "
|
||||
script += "fi; "
|
||||
script += f"rm {pidfile}"
|
||||
|
||||
return script
|
||||
|
||||
|
||||
def _launch_with_io(command_class, arguments, verbose):
|
||||
stop = False
|
||||
exit_codes = [(None, None)] * len(arguments)
|
||||
|
||||
def _thread_fn(rank, *args, **kwargs):
|
||||
stdin_queue = kwargs.pop("stdin_queue")
|
||||
stdout_queue = kwargs.pop("stdout_queue")
|
||||
stderr_queue = kwargs.pop("stderr_queue")
|
||||
|
||||
command = command_class(rank, *args, **kwargs)
|
||||
p = command.process
|
||||
os.set_blocking(p.stdout.fileno(), False)
|
||||
os.set_blocking(p.stderr.fileno(), False)
|
||||
os.set_blocking(p.stdin.fileno(), False)
|
||||
|
||||
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||
to_write = [p.stdin.fileno()]
|
||||
|
||||
stdin_buffer = b""
|
||||
while p.poll() is None:
|
||||
try:
|
||||
stdin_buffer += stdin_queue.get_nowait()
|
||||
except QueueEmpty:
|
||||
pass
|
||||
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||
for fd in rlist:
|
||||
is_stdout = fd == p.stdout.fileno()
|
||||
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||
msg = command.preprocess_output(msg, is_stdout)
|
||||
if is_stdout:
|
||||
stdout_queue.put(msg.encode())
|
||||
else:
|
||||
stderr_queue.put(msg.encode())
|
||||
for fd in wlist:
|
||||
if len(stdin_buffer) > 0:
|
||||
n = os.write(fd, stdin_buffer)
|
||||
stdin_buffer = stdin_buffer[n:]
|
||||
if stop:
|
||||
command.terminate()
|
||||
break
|
||||
exit_codes[rank] = command.exit_status
|
||||
|
||||
if exit_codes[rank][1]:
|
||||
log_warning(f"Node with rank {rank} was killed")
|
||||
elif exit_codes[rank][0] != 0:
|
||||
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
|
||||
else:
|
||||
log(verbose, f"Node with rank {rank} completed")
|
||||
|
||||
stdin_queues = []
|
||||
stdout_queues = []
|
||||
stderr_queues = []
|
||||
threads = []
|
||||
for i, (args, kwargs) in enumerate(arguments):
|
||||
stdin_queues.append(Queue())
|
||||
stdout_queues.append(Queue())
|
||||
stderr_queues.append(Queue())
|
||||
t = threading.Thread(
|
||||
target=_thread_fn,
|
||||
args=args,
|
||||
kwargs=kwargs
|
||||
| {
|
||||
"stdin_queue": stdin_queues[-1],
|
||||
"stdout_queue": stdout_queues[-1],
|
||||
"stderr_queue": stderr_queues[-1],
|
||||
},
|
||||
)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
os.set_blocking(sys.stdin.fileno(), False)
|
||||
os.set_blocking(sys.stdout.fileno(), True)
|
||||
os.set_blocking(sys.stderr.fileno(), True)
|
||||
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
|
||||
# Broadcast user input to the jobs
|
||||
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
|
||||
for fd in rlist:
|
||||
stdin_buffer = os.read(fd, 8192)
|
||||
for q in stdin_queues:
|
||||
q.put(stdin_buffer)
|
||||
|
||||
# Gather job output
|
||||
for q in stdout_queues:
|
||||
try:
|
||||
while not q.empty():
|
||||
sys.stdout.buffer.write(q.get_nowait())
|
||||
except QueueEmpty:
|
||||
pass
|
||||
for q in stderr_queues:
|
||||
try:
|
||||
while not q.empty():
|
||||
sys.stderr.buffer.write(q.get_nowait())
|
||||
except QueueEmpty:
|
||||
pass
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
# Check if all are running and terminate otherwise
|
||||
if any(t.is_alive() for t in threads):
|
||||
for i, t in enumerate(threads):
|
||||
if not t.is_alive():
|
||||
if exit_codes[i][0] != 0:
|
||||
stop = True
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Wait for the jobs to finish
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Process any remaining outputs
|
||||
for q in stdout_queues:
|
||||
while not q.empty():
|
||||
sys.stdout.buffer.write(q.get())
|
||||
for q in stderr_queues:
|
||||
while not q.empty():
|
||||
sys.stderr.buffer.write(q.get())
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
|
||||
def launch_ring(parser, hosts, args, command):
|
||||
if any(len(h.ips) == 0 for h in hosts):
|
||||
parser.error(
|
||||
"The ring backend requires IPs to be provided instead of hostnames"
|
||||
)
|
||||
|
||||
port = args.starting_port
|
||||
ring_hosts = []
|
||||
for h in hosts:
|
||||
node = []
|
||||
for ip in h.ips:
|
||||
for i in range(args.connections_per_ip):
|
||||
node.append(f"{ip}:{port}")
|
||||
port += 1
|
||||
ring_hosts.append(node)
|
||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||
|
||||
files = {"MLX_HOSTFILE": hostfile}
|
||||
env = args.env
|
||||
if args.verbose:
|
||||
env.append("MLX_RING_VERBOSE=1")
|
||||
cwd = args.cwd
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
_launch_with_io(
|
||||
RemoteProcess,
|
||||
[
|
||||
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
|
||||
for rank, h in enumerate(hosts)
|
||||
],
|
||||
args.verbose,
|
||||
)
|
||||
|
||||
|
||||
def launch_nccl(parser, hosts, args, command):
|
||||
if not hosts[0].ips:
|
||||
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||
|
||||
master_host = hosts[0].ips[0]
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
env = args.env
|
||||
cwd = args.cwd
|
||||
if args.verbose:
|
||||
env.append("NCCL_DEBUG=INFO")
|
||||
env.append(f"NCCL_HOST_IP={master_host}")
|
||||
env.append(f"NCCL_PORT={master_port}")
|
||||
env.append(f"MLX_WORLD_SIZE={world_size}")
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
_launch_with_io(
|
||||
RemoteProcess,
|
||||
[
|
||||
(
|
||||
(
|
||||
rank,
|
||||
h.ssh_hostname,
|
||||
args.python,
|
||||
cwd,
|
||||
{},
|
||||
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
||||
command,
|
||||
),
|
||||
{},
|
||||
)
|
||||
for rank, h in enumerate(hosts)
|
||||
],
|
||||
args.verbose,
|
||||
)
|
||||
|
||||
|
||||
def launch_jaccl(parser, hosts, args, command):
    if not hosts[0].ips:
        raise ValueError("Rank 0 should have an IP reachable from all other ranks")

    have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
    have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
    if not have_rdmas or not have_nulls:
        raise ValueError("Malformed hostfile for jaccl backend")

    coordinator = hosts[0].ips[0]
    env = args.env
    cwd = args.cwd
    env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
    files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}

    log(args.verbose, "Running", shlex.join(command))

    _launch_with_io(
        RemoteProcess,
        [
            ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
            for rank, h in enumerate(hosts)
        ],
        args.verbose,
    )
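The checks above expect a square RDMA table: every host carries one entry per peer and None for itself. A minimal sketch (not part of this diff; the device names are hypothetical, and how the table ends up on the Host objects is outside this hunk):

rdma_example = [
    [None, "rdma0"],  # rank 0: no self entry, one device toward rank 1
    ["rdma0", None],  # rank 1: one device toward rank 0, no self entry
]
# All ranks then share MLX_IBV_DEVICES = json.dumps(rdma_example).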
|
||||
|
||||
|
||||
def get_mpi_libname():
    try:
        ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
        ompi_info = ompi_info.stdout.strip().decode()

        if platform.system() == "Darwin":
            otool_output = run(
                ["otool", "-L", ompi_info], check=True, capture_output=True
            )
        else:
            otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
        otool_output = otool_output.stdout.decode()

        # StopIteration if not found
        libmpi_line = next(
            filter(lambda line: "libmpi" in line, otool_output.splitlines())
        )
        return libmpi_line.strip().split()[0].removeprefix("@rpath/")
    except:
        return None


def launch_mpi(parser, hosts, args, command):
    mpirun = run(["which", "mpirun"], check=True, capture_output=True)
    mpirun = mpirun.stdout.strip().decode()

    # Compatibility with homebrew and pip installs
    mpi_libname = get_mpi_libname()
    if mpi_libname is not None:
        dyld = Path(mpirun).parent.parent / "lib"
        args.env = [
            f"DYLD_LIBRARY_PATH={str(dyld)}",
            f"MLX_MPI_LIBNAME={mpi_libname}",
        ] + args.env

    log(args.verbose, f"Using '{mpirun}'")
    with tempfile.NamedTemporaryFile(mode="w") as f:
        hosts = Counter((h.ssh_hostname for h in hosts))
        for h, n in hosts.items():
            print(f"{h} slots={n}", file=f)
        f.flush()

        cmd = [
            mpirun,
            "--output",
            ":raw",  # do not line buffer output
            "--hostfile",
            f.name,
            *(["-cwd", args.cwd] if args.cwd else []),
            *sum((["-x", e] for e in args.env), []),
            *sum([shlex.split(arg) for arg in args.mpi_arg], []),
            "--",
            *command,
        ]
        log(args.verbose, "Running", " ".join(cmd))
        try:
            run(cmd)
        except KeyboardInterrupt:
            pass
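A minimal sketch (not part of this diff) of the temporary hostfile launch_mpi writes, for the hypothetical host list hostA, hostA, hostB:

from collections import Counter

slots = Counter(["hostA", "hostA", "hostB"])
print("\n".join(f"{h} slots={n}" for h, n in slots.items()))
# hostA slots=2
# hostB slots=1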
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||
parser.add_argument(
|
||||
"--print-python",
|
||||
action="store_true",
|
||||
help="Print the path to the current python executable and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat-hosts",
|
||||
"-n",
|
||||
type=positive_number,
|
||||
default=1,
|
||||
help="Repeat each host a given number of times",
|
||||
)
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||
default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Set environment variables for the jobs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mpi-arg",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Arguments to pass directly to mpirun",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connections-per-ip",
|
||||
default=1,
|
||||
type=int,
|
||||
help="How many connections per ip to use for the ring backend",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--starting-port",
|
||||
"-p",
|
||||
type=int,
|
||||
default=32323,
|
||||
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cwd", help="Set the working directory on each node to the provided one"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nccl-port",
|
||||
type=int,
|
||||
default=12345,
|
||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify-script",
|
||||
action="store_false",
|
||||
dest="verify_script",
|
||||
help="Do not verify that the script exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--python", default=sys.executable, help="Use this python on the remote hosts"
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
|
||||
if args.print_python:
|
||||
print(args.python)
|
||||
return
|
||||
|
||||
if len(rest) == 0:
|
||||
parser.error("No script is provided")
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
|
||||
# Try to extract a list of hosts and corresponding ips
|
||||
if args.hostfile is not None:
|
||||
hosts = parse_hostfile(parser, args.hostfile)
|
||||
else:
|
||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||
|
||||
# Check if the script is a file and convert it to a full path
|
||||
if (script := Path(rest[0])).exists() and script.is_file():
|
||||
rest[0:1] = [args.python, str(script.resolve())]
|
||||
elif (command := shutil.which(rest[0])) is not None:
|
||||
rest[0] = command
|
||||
elif args.verify_script:
|
||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||
|
||||
    # Launch
    if args.backend == "ring":
        launch_ring(parser, hosts, args, rest)
    if args.backend == "mpi":
        launch_mpi(parser, hosts, args, rest)
    if args.backend == "nccl":
        launch_nccl(parser, hosts, args, rest)
    if args.backend == "jaccl":
        launch_jaccl(parser, hosts, args, rest)
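Tying the above together, typical invocations through the mlx.launch entry point could look like the following (a sketch; the script name and hostfile path are hypothetical, and a jaccl hostfile additionally needs the RDMA table sketched earlier):

mlx.launch --verbose --backend ring --hostfile hosts.json train.py
mlx.launch --backend jaccl --hostfile hosts.json train.py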
|
||||
@@ -1,909 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from queue import Empty as QueueEmpty
|
||||
from queue import Queue
|
||||
from select import select
|
||||
from subprocess import PIPE, Popen, run
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Host:
|
||||
rank: int
|
||||
ssh_hostname: str
|
||||
ips: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThunderboltPort:
|
||||
iface: str
|
||||
uuid: str
|
||||
connected_to: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThunderboltHost:
|
||||
name: str
|
||||
ports: list[ThunderboltPort]
|
||||
|
||||
|
||||
def parse_hardware_ports(ports_string):
    ports = {}
    port_name = None
    for l in ports_string.decode("utf-8").split("\n"):
        if l.startswith("Hardware Port:"):
            port_name = l.strip()[15:]
        elif l.startswith("Device:"):
            ports[port_name] = l.strip()[8:]
            port_name = None
    return ports
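For example (a sketch, not part of this diff; the interface names are hypothetical), the parser above maps networksetup -listallhardwareports output to interface names:

sample = (
    b"Hardware Port: Thunderbolt 1\nDevice: en2\n\n"
    b"Hardware Port: Thunderbolt 2\nDevice: en3\n"
)
print(parse_hardware_ports(sample))
# {'Thunderbolt 1': 'en2', 'Thunderbolt 2': 'en3'}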
|
||||
|
||||
|
||||
def get_num_nvidia_gpus():
|
||||
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
|
||||
return len(result.stdout.strip().split("\n"))
|
||||
|
||||
|
||||
def extract_rings(hosts, index):
|
||||
def usable_port(i, j, used_ports):
|
||||
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
|
||||
|
||||
def dfs(start_node, node, path, visited, used_ports):
|
||||
path.append(node)
|
||||
visited.add(node)
|
||||
for j, p in enumerate(hosts[node].ports):
|
||||
if not usable_port(node, j, used_ports):
|
||||
continue
|
||||
next_node, _ = index[p.connected_to]
|
||||
if next_node == start_node:
|
||||
yield path[:]
|
||||
if next_node not in visited:
|
||||
yield from dfs(start_node, next_node, path, visited, used_ports)
|
||||
path.pop()
|
||||
visited.remove(node)
|
||||
|
||||
# Concretize maps the found cycle to real thunderbolt ports. It also adds
|
||||
# those ports to the used set so next cycles can't use them again.
|
||||
def concretize(cycle, used_ports):
|
||||
concrete_path = []
|
||||
for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
|
||||
for j, p in enumerate(hosts[n1].ports):
|
||||
if not usable_port(n1, j, used_ports):
|
||||
continue
|
||||
n2_hat, nj = index[p.connected_to]
|
||||
if n2 == n2_hat:
|
||||
concrete_path.append(((n1, j), (n2, nj)))
|
||||
used_ports.add((n1, j))
|
||||
used_ports.add((n2, nj))
|
||||
break
|
||||
if concrete_path[-1][0][0] != n1:
|
||||
raise RuntimeError("Couldn't concretize the cycle")
|
||||
return concrete_path
|
||||
|
||||
# Normalize tries to ensure that the cycles have the same direction so we can
|
||||
# use them together. We achieve this by selecting the direction such that
|
||||
# the smallest rank hosts connect to larger rank hosts.
|
||||
def normalize(path):
|
||||
small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
|
||||
if small_to_large > len(path) - small_to_large:
|
||||
return path
|
||||
else:
|
||||
return [(p[1], p[0]) for p in path]
|
||||
|
||||
rings = []
|
||||
used_ports = set()
|
||||
for start_node in range(len(hosts)):
|
||||
while True:
|
||||
ring = []
|
||||
for r in dfs(start_node, start_node, [], set(), used_ports):
|
||||
if len(r) > len(ring):
|
||||
ring = r
|
||||
# Break early since we won't find a bigger ring no matter what
|
||||
if len(ring) == len(hosts):
|
||||
break
|
||||
if not ring:
|
||||
break
|
||||
try:
|
||||
rings.append(normalize(concretize(ring, used_ports)))
|
||||
except RuntimeError:
|
||||
if len(rings) > 0:
|
||||
return rings
|
||||
raise
|
||||
|
||||
return rings
|
||||
|
||||
|
||||
def positive_number(x):
|
||||
x = int(x)
|
||||
if x <= 0:
|
||||
raise ValueError("Number should be positive")
|
||||
return x
|
||||
|
||||
|
||||
def log(verbose, *args, **kwargs):
|
||||
if not verbose:
|
||||
return
|
||||
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_warning(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_error(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def parse_hostfile(parser, hostfile):
|
||||
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||
the ips to communicate over when using the ring backend.
|
||||
|
||||
Example:
|
||||
|
||||
[
|
||||
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
||||
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
||||
...
|
||||
{"ssh": "hostnameN", "ips": ["123.123.123.N"]},
|
||||
]
|
||||
|
||||
Args:
|
||||
hostfile (str): The path to the json file containing the host
|
||||
information
|
||||
"""
|
||||
hostfile = Path(hostfile)
|
||||
if not hostfile.exists():
|
||||
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
|
||||
|
||||
try:
|
||||
hosts = []
|
||||
with open(hostfile) as f:
|
||||
for i, h in enumerate(json.load(f)):
|
||||
hosts.append(Host(i, h["ssh"], h.get("ips", [])))
|
||||
return hosts
|
||||
except Exception as e:
|
||||
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||
|
||||
|
||||
def parse_hostlist(parser, hostlist, repeats):
|
||||
hosts = []
|
||||
for i, h in enumerate(hostlist.split(",")):
|
||||
if h == "":
|
||||
raise ValueError("Hostname cannot be empty")
|
||||
try:
|
||||
ipaddress.ip_address(h)
|
||||
ips = [h]
|
||||
except ValueError:
|
||||
ips = []
|
||||
for i in range(repeats):
|
||||
hosts.append(Host(i, h, ips))
|
||||
return hosts
|
||||
|
||||
|
||||
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
|
||||
# Imports that are used throughout
|
||||
script = ""
|
||||
script += "import os\n"
|
||||
script += "import sys\n"
|
||||
script += "import tempfile\n"
|
||||
script += "from pathlib import Path\n"
|
||||
|
||||
# Write the PID to a file so we can kill the process if needed
|
||||
script += "_, pidfile = tempfile.mkstemp() \n"
|
||||
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
||||
script += "print(pidfile, flush=True)\n"
|
||||
|
||||
# Change the working directory if one was requested. Otherwise attempt to
|
||||
# change to the current one but don't fail if it wasn't possible.
|
||||
d = cwd or os.getcwd()
|
||||
script += f"if Path({repr(d)}).exists():\n"
|
||||
script += f" os.chdir({repr(d)})\n"
|
||||
if cwd is not None:
|
||||
script += "else:\n"
|
||||
script += (
|
||||
f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
||||
)
|
||||
script += f" sys.exit(1)\n"
|
||||
|
||||
# Add the environment variables that were given to us
|
||||
script += "env = dict(os.environ)\n"
|
||||
for e in env:
|
||||
key, *value = e.split("=", maxsplit=1)
|
||||
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
|
||||
continue
|
||||
script += f"env[{repr(key)}] = {repr(value)}\n"
|
||||
|
||||
# Add the environment variables to enable the ring distributed backend
|
||||
if hostfile != "":
|
||||
script += "_, hostfile = tempfile.mkstemp()\n"
|
||||
script += "with open(hostfile, 'w') as f:\n"
|
||||
script += f" f.write({repr(hostfile)})\n"
|
||||
if verbose:
|
||||
script += "env['MLX_RING_VERBOSE'] = '1'\n"
|
||||
script += "env['MLX_HOSTFILE'] = hostfile\n"
|
||||
script += f"env['MLX_RANK'] = '{rank}'\n"
|
||||
script += "\n"
|
||||
|
||||
# Replace the process with the script
|
||||
script += f"command = [{','.join(map(repr, command))}]\n"
|
||||
script += "os.execve(command[0], command, env)\n"
|
||||
|
||||
return script
|
||||
|
||||
|
||||
def launch_ring(parser, hosts, args, command):
|
||||
stop = False
|
||||
exit_codes = [None] * len(hosts)
|
||||
|
||||
def node_thread(rank, host, hostfile, input_queue):
|
||||
is_local = host == "127.0.0.1"
|
||||
script = make_monitor_script(
|
||||
rank, hostfile, args.cwd, args.env, command, args.verbose
|
||||
)
|
||||
script_b64 = base64.b64encode(script.encode()).decode()
|
||||
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||
if not is_local:
|
||||
cmd = f"ssh {host} '{cmd}'"
|
||||
p = Popen(
|
||||
cmd,
|
||||
shell=True,
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
os.set_blocking(p.stdout.fileno(), False)
|
||||
os.set_blocking(p.stderr.fileno(), False)
|
||||
os.set_blocking(p.stdin.fileno(), False)
|
||||
|
||||
# Repeat the stdout and stderr to the local machine
|
||||
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
|
||||
pidfile = ""
|
||||
stdin_buffer = b""
|
||||
stdout_buffer = b""
|
||||
stderr_buffer = b""
|
||||
while p.poll() is None:
|
||||
try:
|
||||
stdin_buffer += input_queue.get_nowait()
|
||||
except QueueEmpty:
|
||||
pass
|
||||
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||
for fd in rlist:
|
||||
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||
|
||||
# Fetch the PID file first if we haven't already
|
||||
if pidfile == "":
|
||||
pidfile, *msg = msg.split("\n", maxsplit=1)
|
||||
msg = msg[0] if msg else ""
|
||||
|
||||
is_stdout = fd == p.stdout.fileno()
|
||||
if is_stdout:
|
||||
stdout_buffer += msg.encode()
|
||||
else:
|
||||
stderr_buffer += msg.encode()
|
||||
for fd in wlist:
|
||||
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
|
||||
n = os.write(fd, stdin_buffer)
|
||||
stdin_buffer = stdin_buffer[n:]
|
||||
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
|
||||
n = os.write(fd, stdout_buffer)
|
||||
stdout_buffer = stdout_buffer[n:]
|
||||
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
|
||||
n = os.write(fd, stderr_buffer)
|
||||
stderr_buffer = stderr_buffer[n:]
|
||||
if stop:
|
||||
p.terminate()
|
||||
break
|
||||
p.wait()
|
||||
exit_codes[rank] = p.returncode
|
||||
|
||||
# Kill the remote program if possible
|
||||
cmd = ""
|
||||
cmd += f"pid=$(cat {pidfile}); "
|
||||
cmd += "if ps -p $pid >/dev/null; then "
|
||||
cmd += " kill $pid; "
|
||||
cmd += " echo 1; "
|
||||
cmd += "else "
|
||||
cmd += " echo 0; "
|
||||
cmd += "fi; "
|
||||
cmd += f"rm {pidfile}"
|
||||
if not is_local:
|
||||
cmd = f"ssh {host} '{cmd}'"
|
||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||
if c.stdout.strip() == "1":
|
||||
log_warning(f"Node with rank {rank} was killed")
|
||||
elif p.returncode != 0:
|
||||
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
|
||||
else:
|
||||
log(args.verbose, f"Node with rank {rank} completed")
|
||||
|
||||
if all(len(h.ips) == 0 for h in hosts):
|
||||
parser.error(
|
||||
"The ring backend requires IPs to be provided instead of hostnames"
|
||||
)
|
||||
|
||||
port = args.starting_port
|
||||
ring_hosts = []
|
||||
for h in hosts:
|
||||
node = []
|
||||
for ip in h.ips:
|
||||
for i in range(args.connections_per_ip):
|
||||
node.append(f"{ip}:{port}")
|
||||
port += 1
|
||||
ring_hosts.append(node)
|
||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
input_queues = []
|
||||
threads = []
|
||||
for i, h in enumerate(hosts):
|
||||
if i + 1 == len(hosts):
|
||||
time.sleep(1.0)
|
||||
input_queues.append(Queue())
|
||||
t = threading.Thread(
|
||||
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
|
||||
)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
os.set_blocking(sys.stdin.fileno(), False)
|
||||
while not stop:
|
||||
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
|
||||
for fd in rlist:
|
||||
stdin_buffer = os.read(fd, 8192)
|
||||
for q in input_queues:
|
||||
q.put(stdin_buffer)
|
||||
if any(t.is_alive() for t in threads):
|
||||
for i, t in enumerate(threads):
|
||||
if not t.is_alive():
|
||||
if exit_codes[i] != 0:
|
||||
stop = True
|
||||
break
|
||||
else:
|
||||
break
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
def get_mpi_libname():
|
||||
try:
|
||||
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
|
||||
ompi_info = ompi_info.stdout.strip().decode()
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
otool_output = run(
|
||||
["otool", "-L", ompi_info], check=True, capture_output=True
|
||||
)
|
||||
else:
|
||||
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
|
||||
otool_output = otool_output.stdout.decode()
|
||||
|
||||
# StopIteration if not found
|
||||
libmpi_line = next(
|
||||
filter(lambda line: "libmpi" in line, otool_output.splitlines())
|
||||
)
|
||||
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def launch_mpi(parser, hosts, args, command):
|
||||
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
|
||||
mpirun = mpirun.stdout.strip().decode()
|
||||
|
||||
# Compatibility with homebrew and pip installs
|
||||
mpi_libname = get_mpi_libname()
|
||||
if mpi_libname is not None:
|
||||
dyld = Path(mpirun).parent.parent / "lib"
|
||||
args.env = [
|
||||
f"DYLD_LIBRARY_PATH={str(dyld)}",
|
||||
f"MLX_MPI_LIBNAME={mpi_libname}",
|
||||
] + args.env
|
||||
|
||||
log(args.verbose, f"Using '{mpirun}'")
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
hosts = Counter((h.ssh_hostname for h in hosts))
|
||||
for h, n in hosts.items():
|
||||
print(f"{h} slots={n}", file=f)
|
||||
f.flush()
|
||||
|
||||
cmd = [
|
||||
mpirun,
|
||||
"--output",
|
||||
":raw", # do not line buffer output
|
||||
"--hostfile",
|
||||
f.name,
|
||||
*(["-cwd", args.cwd] if args.cwd else []),
|
||||
*sum((["-x", e] for e in args.env), []),
|
||||
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||
"--",
|
||||
*command,
|
||||
]
|
||||
log(args.verbose, "Running", " ".join(cmd))
|
||||
try:
|
||||
run(cmd)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
def launch_nccl(parser, hosts, args, command):
|
||||
master_host = hosts[0].ips[0]
|
||||
|
||||
if master_host != "127.0.0.1":
|
||||
raise ValueError("The NCCL backend only supports localhost for now.")
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
base_env = os.environ.copy()
|
||||
base_env.update(
|
||||
{
|
||||
"NCCL_DEBUG": base_env.get(
|
||||
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
|
||||
),
|
||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||
"NCCL_HOST_IP": master_host,
|
||||
"NCCL_PORT": str(master_port),
|
||||
"MLX_WORLD_SIZE": str(world_size),
|
||||
}
|
||||
)
|
||||
procs = []
|
||||
num_gpus = get_num_nvidia_gpus()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
|
||||
if args.repeat_hosts > num_gpus:
|
||||
raise RuntimeError("NCCL requires a separate GPU per process.")
|
||||
|
||||
try:
|
||||
for rank in range(world_size):
|
||||
env = base_env.copy()
|
||||
mlx_rank = str(rank % args.repeat_hosts)
|
||||
env["MLX_RANK"] = mlx_rank
|
||||
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
|
||||
p = Popen(command, env=env)
|
||||
procs.append(p)
|
||||
|
||||
for p in procs:
|
||||
ret = p.wait()
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Rank process exited with {ret}")
|
||||
|
||||
except (RuntimeError, KeyboardInterrupt) as err:
|
||||
for p in procs:
|
||||
if p.poll() is None:
|
||||
try:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def check_ssh_connections(hosts):
|
||||
results = [False] * len(hosts)
|
||||
|
||||
def _check(hostname, i):
|
||||
result = run(
|
||||
[
|
||||
"ssh",
|
||||
"-o",
|
||||
"BatchMode=yes",
|
||||
"-o",
|
||||
"ConnectTimeout=5",
|
||||
hostname,
|
||||
"echo",
|
||||
"success",
|
||||
],
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
results[i] = result.returncode == 0
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=_check, args=(h.ssh_hostname, i))
|
||||
for i, h in enumerate(hosts)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
if not all(results):
|
||||
log_error("Could not ssh to the following hosts:")
|
||||
for i, h in enumerate(hosts):
|
||||
if not results[i]:
|
||||
log_error(" - ", h.ssh_hostname)
|
||||
log_error()
|
||||
log_error("Maybe they are not set-up for password-less ssh?")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def prepare_tb_ring(args, hosts):
|
||||
log(
|
||||
args.verbose,
|
||||
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
|
||||
)
|
||||
|
||||
# Check that we can ssh
|
||||
check_ssh_connections(hosts)
|
||||
if args.auto_setup and args.verbose:
|
||||
log_warning(
|
||||
"--auto-setup is requested which requires password-less sudo",
|
||||
"on the remote hosts",
|
||||
)
|
||||
|
||||
# Extract the current connectivity from the remote hosts
|
||||
thunderbolt_connections = []
|
||||
for h in hosts:
|
||||
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
|
||||
thunderbolt_connections.append(
|
||||
json.loads(
|
||||
run(
|
||||
[
|
||||
"ssh",
|
||||
h.ssh_hostname,
|
||||
"system_profiler",
|
||||
"SPThunderboltDataType",
|
||||
"-json",
|
||||
],
|
||||
capture_output=True,
|
||||
).stdout
|
||||
)
|
||||
)
|
||||
interface_maps = []
|
||||
for h in hosts:
|
||||
log(args.verbose, "Getting interface names from", h.ssh_hostname)
|
||||
interface_maps.append(
|
||||
parse_hardware_ports(
|
||||
run(
|
||||
[
|
||||
"ssh",
|
||||
h.ssh_hostname,
|
||||
"networksetup",
|
||||
"-listallhardwareports",
|
||||
],
|
||||
capture_output=True,
|
||||
).stdout
|
||||
)
|
||||
)
|
||||
|
||||
# Parse the connectivity into some simple dataclasses
|
||||
tb_hosts = []
|
||||
for c, iface_map in zip(thunderbolt_connections, interface_maps):
|
||||
name = ""
|
||||
ports = []
|
||||
for t in c["SPThunderboltDataType"]:
|
||||
uuid = t.get("domain_uuid_key")
|
||||
if uuid is None:
|
||||
continue
|
||||
name = t["device_name_key"]
|
||||
tag = t["receptacle_1_tag"]["receptacle_id_key"]
|
||||
items = t.get("_items", [])
|
||||
connected_items = [item for item in items if "domain_uuid_key" in item]
|
||||
connected_to = (
|
||||
connected_items[0]["domain_uuid_key"] if connected_items else None
|
||||
)
|
||||
iface = iface_map[f"Thunderbolt {tag}"]
|
||||
ports.append(ThunderboltPort(iface, uuid, connected_to))
|
||||
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
|
||||
|
||||
# Create a reverse index to be able to map uuids to (host, port) quickly
|
||||
uuid_reverse_index = {}
|
||||
for i, h in enumerate(tb_hosts):
|
||||
for j, p in enumerate(h.ports):
|
||||
uuid_reverse_index[p.uuid] = (i, j)
|
||||
|
||||
# Find the rings by simply walking and marking visited (host, port) tuples
|
||||
# and keeping the largest rings greedily.
|
||||
log(args.verbose, "Extracting rings from the parsed connectivity")
|
||||
rings = extract_rings(tb_hosts, uuid_reverse_index)
|
||||
|
||||
# Just output a DOT graphical representation of the found rings
|
||||
if args.dot:
|
||||
names = []
|
||||
for i in range(len(tb_hosts)):
|
||||
n = ""
|
||||
j = i
|
||||
while True:
|
||||
n += chr(97 + j % 26)
|
||||
j //= 26
|
||||
if j == 0:
|
||||
break
|
||||
names.append(n)
|
||||
|
||||
print("graph G {")
|
||||
print(" node [shape=rectangle];")
|
||||
for i, h in enumerate(hosts):
|
||||
print(f' {names[i]} [label="{h.ssh_hostname}"];')
|
||||
for r in rings:
|
||||
for (i, _), (j, _) in r:
|
||||
print(f" {names[i]} -- {names[j]};")
|
||||
print("}")
|
||||
return
|
||||
|
||||
# Assign IPs to each interface such that the interfaces can communicate
|
||||
ips = {}
|
||||
pairs = {}
|
||||
expecting = set()
|
||||
ip0 = 0
|
||||
ip1 = 0
|
||||
netmask = "255.255.255.252"
|
||||
for r in rings:
|
||||
for a, b in r:
|
||||
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
|
||||
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
|
||||
pairs[a] = b
|
||||
pairs[b] = a
|
||||
expecting.add(b)
|
||||
ip1 += 4
|
||||
if ip1 > 255:
|
||||
ip0 += 1
|
||||
ip1 = 0
|
||||
if ip0 > 255:
|
||||
raise ValueError("Ran out of available local IPs for the ring")
|
||||
|
||||
# Extract the host order from the first ring
|
||||
hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
|
||||
first_host = min(hostmap.keys())
|
||||
order = [first_host]
|
||||
while hostmap[order[-1]] != first_host:
|
||||
order.append(hostmap[order[-1]])
|
||||
|
||||
# Create the hostfile
|
||||
hostfile = []
|
||||
for i in order:
|
||||
h = hosts[i]
|
||||
host = {
|
||||
"ssh": h.ssh_hostname,
|
||||
"ips": [
|
||||
ips[i, j]
|
||||
for j, p in enumerate(tb_hosts[i].ports)
|
||||
if (i, j) in expecting
|
||||
],
|
||||
}
|
||||
hostfile.append(host)
|
||||
|
||||
if not args.hostfile_only:
|
||||
for i, h in enumerate(hosts):
|
||||
command = ""
|
||||
command += "sudo ifconfig bridge0 down\n"
|
||||
for j, p in enumerate(tb_hosts[i].ports):
|
||||
if (i, j) not in ips:
|
||||
continue
|
||||
iface = p.iface
|
||||
ip = ips[i, j]
|
||||
peer = ips[pairs[i, j]]
|
||||
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
|
||||
command += f"sudo route change {peer} -interface {iface}\n"
|
||||
if args.auto_setup:
|
||||
print(f"Running auto setup for {h.ssh_hostname}")
|
||||
command = command.strip().replace("\n", " && ")
|
||||
command = ["ssh", h.ssh_hostname, command]
|
||||
log(args.verbose, shlex.join(command))
|
||||
run(command)
|
||||
else:
|
||||
msg = f"Setup for {h.ssh_hostname}"
|
||||
print(msg)
|
||||
print("=" * len(msg))
|
||||
print(command)
|
||||
input("Enter to continue")
|
||||
print()
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
json.dump(hostfile, f, indent=4)
|
||||
else:
|
||||
print("Hostfile")
|
||||
print("========")
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def prepare_hostfile(args, hosts):
|
||||
log(
|
||||
args.verbose,
|
||||
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
|
||||
)
|
||||
|
||||
# Check that we can ssh
|
||||
check_ssh_connections(hosts)
|
||||
|
||||
# Get the ips for each host
|
||||
for h in hosts:
|
||||
log(args.verbose, "Getting the ip from", h.ssh_hostname)
|
||||
h.ips.append(
|
||||
run(
|
||||
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
).stdout.strip()
|
||||
)
|
||||
|
||||
hostfile = []
|
||||
for h in hosts:
|
||||
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
json.dump(hostfile, f, indent=4)
|
||||
else:
|
||||
print("Hostfile")
|
||||
print("========")
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def distributed_config():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Configure remote machines for use with MLX distributed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--over",
|
||||
choices=["thunderbolt", "ethernet"],
|
||||
default="thunderbolt",
|
||||
help="What type of connectivity to configure",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||
)
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--dot", action="store_true", help="Output the topology in DOT format and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-hostfile", help="If provided, save the hostfile to this path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-setup",
|
||||
action="store_true",
|
||||
help="If set we will attempt to automatically configure the machines via ssh",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.backend == "mpi" and args.over == "thunderbolt":
|
||||
raise ValueError(
|
||||
(
|
||||
"The configuration of MPI over thunderbolt is "
|
||||
"not supported yet by mlx.distributed_config"
|
||||
)
|
||||
)
|
||||
|
||||
if args.hostfile is not None:
|
||||
hosts = parse_hostfile(parser, args.hostfile)
|
||||
else:
|
||||
hosts = parse_hostlist(parser, args.hosts, 1)
|
||||
|
||||
if args.over == "thunderbolt":
|
||||
prepare_tb_ring(args, hosts)
|
||||
else:
|
||||
prepare_hostfile(args, hosts)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||
parser.add_argument(
|
||||
"--print-python",
|
||||
action="store_true",
|
||||
help="Print the path to the current python executable and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat-hosts",
|
||||
"-n",
|
||||
type=positive_number,
|
||||
default=1,
|
||||
help="Repeat each host a given number of times",
|
||||
)
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Set environment variables for the jobs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mpi-arg",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Arguments to pass directly to mpirun",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connections-per-ip",
|
||||
default=1,
|
||||
type=int,
|
||||
help="How many connections per ip to use for the ring backend",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--starting-port",
|
||||
"-p",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cwd", help="Set the working directory on each node to the provided one"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nccl-port",
|
||||
type=int,
|
||||
default=12345,
|
||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
|
||||
if args.print_python:
|
||||
print(sys.executable)
|
||||
return
|
||||
|
||||
if len(rest) == 0:
|
||||
parser.error("No script is provided")
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
|
||||
# Try to extract a list of hosts and corresponding ips
|
||||
if args.hostfile is not None:
|
||||
hosts = parse_hostfile(parser, args.hostfile)
|
||||
else:
|
||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||
|
||||
# Check if the script is a file and convert it to a full path
|
||||
if (script := Path(rest[0])).exists():
|
||||
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||
elif (command := shutil.which(rest[0])) is not None:
|
||||
rest[0] = command
|
||||
else:
|
||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||
|
||||
# Launch
|
||||
if args.backend == "ring":
|
||||
launch_ring(parser, hosts, args, rest)
|
||||
if args.backend == "mpi":
|
||||
launch_mpi(parser, hosts, args, rest)
|
||||
if args.backend == "nccl":
|
||||
launch_nccl(parser, hosts, args, rest)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {

  m.def(
      "is_available",
      &mx::distributed::is_available,
      [](const std::string& backend) {
        return mx::distributed::is_available(backend);
      },
      "backend"_a = "any",
      nb::sig("def is_available(backend: str = 'any') -> bool"),
      R"pbdoc(
      Check if a communication backend is available.

      Note, this function returns whether MLX has the capability of
      instantiating that distributed backend not whether it is possible to
      create a communication group. For that purpose one should use
      ``init(strict=True)``.

      Args:
        backend (str, optional): The name of the backend to check for availability.
          It takes the same values as ``init()``. Default: ``any``.

      Returns:
        bool: Whether the distributed backend is available.
      )pbdoc");

  m.def(
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
        in case ``mx.distributed.is_available()`` returns False otherwise
        it throws a runtime error. Default: ``False``
      backend (str, optional): Which distributed backend to initialize.
        Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
        available backends are tried and the first one that succeeds
        becomes the global group which will be returned in subsequent
        calls. Default: ``any``
        Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
        set to ``any`` all available backends are tried and the first one
        that succeeds becomes the global group which will be returned in
        subsequent calls. Default: ``any``

      Returns:
        Group: The group representing all the launched processes.
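And a corresponding sketch of choosing the backend at init time (process launch details omitted):

import mlx.core as mx

backend = "jaccl" if mx.distributed.is_available("jaccl") else "any"
group = mx.distributed.init(backend=backend)
print(group.rank(), group.size())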

setup.py
@@ -265,8 +265,8 @@ if __name__ == "__main__":
    }
    entry_points = {
        "console_scripts": [
            "mlx.launch = mlx.distributed_run:main",
            "mlx.distributed_config = mlx.distributed_run:distributed_config",
            "mlx.launch = mlx._distributed_utils.launch:main",
            "mlx.distributed_config = mlx._distributed_utils.config:main",
        ]
    }
    install_requires = []