Compare commits

..

2 Commits

Author SHA1 Message Date
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
24 changed files with 396 additions and 2214 deletions

View File

@@ -1,6 +1,15 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
@@ -12,4 +21,4 @@ runs:
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -15,6 +15,7 @@ runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,7 +128,10 @@ jobs:
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
strategy:
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -139,6 +142,8 @@ jobs:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -119,10 +119,6 @@ if(MLX_BUILD_METAL)
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

View File

@@ -89,9 +89,13 @@ template <
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
int N_READS = 4,
int BLOCKS = 1>
__global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value();
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result.
if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args,
int bn) {
int bn,
int outer = 1) {
int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
}
@@ -277,7 +298,8 @@ void col_reduce_looped(
0,
indata,
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
@@ -320,6 +342,117 @@ void col_reduce_small(
});
}
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
@@ -349,6 +494,14 @@ void col_reduce(
return;
}
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -4,11 +4,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
if(MLX_BUILD_CPU AND NOT WIN32)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

View File

@@ -5,7 +5,6 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/jaccl/jaccl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
@@ -103,27 +102,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
jaccl::is_available();
}
bool is_available(const std::string& bk) {
if (bk == "any") {
return is_available();
}
if (bk == "mpi") {
return mpi::is_available();
}
if (bk == "ring") {
return ring::is_available();
}
if (bk == "nccl") {
return nccl::is_available();
}
if (bk == "jaccl") {
return jaccl::is_available();
}
return false;
return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@@ -156,8 +135,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "jaccl") {
group = jaccl::init(strict);
} else if (bk == "any") {
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
@@ -171,17 +148,13 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr) {
group = jaccl::init(false);
bk_ = "jaccl";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}

View File

@@ -16,7 +16,6 @@ class GroupImpl;
/* Check if a communication backend is available */
bool is_available();
bool is_available(const std::string& bk);
/**
* A distributed::Group represents a group of independent mlx processes that

View File

@@ -1,8 +0,0 @@
if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
target_link_libraries(mlx PRIVATE rdma)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::jaccl

View File

@@ -1,20 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/jaccl/jaccl.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::jaccl

View File

@@ -1,38 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::distributed::detail {
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace mlx::core::distributed::detail

View File

@@ -1,6 +1,9 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@@ -19,8 +22,6 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
@@ -93,7 +94,6 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const char* RING_TAG = "[ring]";
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -296,6 +296,55 @@ class CommunicationThreads {
std::unordered_map<int, SocketThread> threads_;
};
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
@@ -308,15 +357,15 @@ class CommunicationThreads {
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<detail::address_t>> nodes;
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<detail::address_t> host;
std::vector<address_t> host;
for (auto& ips : h) {
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
host.push_back(parse_address(ips.get<std::string>()));
}
nodes.push_back(std::move(host));
}
@@ -328,15 +377,73 @@ std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(
const std::vector<detail::address_t>& addresses) {
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
detail::TCPSocket socket(RING_TAG);
socket.listen(RING_TAG, address);
sockets.push_back(socket.accept(RING_TAG).detach());
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
}
return sockets;
@@ -347,42 +454,93 @@ std::vector<int> accept_connections(
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<detail::address_t>& addresses,
const std::vector<address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
sockets.push_back(detail::TCPSocket::connect(
RING_TAG,
address,
CONN_ATTEMPTS,
CONN_WAIT,
[verbose](int attempt, int wait) {
log_info(
verbose,
"Attempt",
attempt,
"waiting",
wait,
"ms (error:",
errno,
")");
})
.detach());
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
}
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace
class RingGroup : public GroupImpl {
public:
RingGroup(
int rank,
std::vector<std::vector<detail::address_t>> nodes,
bool verbose)
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
: rank_(rank), verbose_(verbose), pool_(0) {
if (rank_ > 0 && rank_ >= nodes.size()) {
throw std::runtime_error(
@@ -475,17 +633,17 @@ class RingGroup : public GroupImpl {
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

View File

@@ -1,204 +0,0 @@
// Copyright © 2025 Apple Inc.
#include <netdb.h>
#include <unistd.h>
#include <cstring>
#include <sstream>
#include <thread>
#include "mlx/distributed/utils.h"
namespace mlx::core::distributed::detail {
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
TCPSocket::TCPSocket(const char* tag) {
sock_ = socket(AF_INET, SOCK_STREAM, 0);
if (sock_ < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket::TCPSocket(TCPSocket&& s) {
sock_ = s.sock_;
s.sock_ = -1;
}
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
if (this != &s) {
sock_ = s.sock_;
s.sock_ = -1;
}
return *this;
}
TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() {
if (sock_ > 0) {
shutdown(sock_, 2);
close(sock_);
}
}
int TCPSocket::detach() {
int s = sock_;
sock_ = -1;
return s;
}
void TCPSocket::listen(const char* tag, const address_t& addr) {
int success;
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock_, addr.get(), addr.len);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Prepare waiting for connections
success = ::listen(sock_, 0);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket TCPSocket::accept(const char* tag) {
int peer = ::accept(sock_, nullptr, nullptr);
if (peer < 0) {
std::ostringstream msg;
msg << tag << " Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(peer);
}
void TCPSocket::send(const char* tag, const void* data, size_t len) {
while (len > 0) {
auto n = ::send(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Send failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<const char*>(data) + n;
}
}
void TCPSocket::recv(const char* tag, void* data, size_t len) {
while (len > 0) {
auto n = ::recv(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Recv failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<char*>(data) + n;
}
}
TCPSocket TCPSocket::connect(
const char* tag,
const address_t& addr,
int num_retries,
int wait,
std::function<void(int, int)> cb) {
int sock, success;
// Attempt to connect `num_retries` times with exponential backoff.
for (int attempt = 0; attempt < num_retries; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket to connect (error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
success = ::connect(sock, addr.get(), addr.len);
if (success == 0) {
break;
}
cb(attempt, wait);
if (wait > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
wait <<= 1;
}
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(sock);
}
} // namespace mlx::core::distributed::detail

View File

@@ -1,67 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/socket.h>
#include <functional>
#include <string>
namespace mlx::core::distributed::detail {
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port);
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port);
/**
* Small wrapper over a TCP socket to simplify initiating connections.
*/
class TCPSocket {
public:
TCPSocket(const char* tag);
TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s);
TCPSocket& operator=(TCPSocket&&);
~TCPSocket();
void listen(const char* tag, const address_t& addr);
TCPSocket accept(const char* tag);
void send(const char* tag, const void* data, size_t len);
void recv(const char* tag, void* data, size_t len);
int detach();
operator int() const {
return sock_;
}
static TCPSocket connect(
const char* tag,
const address_t& addr,
int num_retries = 1,
int wait = 0,
std::function<void(int, int)> cb = nullptr);
private:
TCPSocket(int sock);
int sock_;
};
} // namespace mlx::core::distributed::detail

View File

@@ -1,85 +0,0 @@
# Copyright © 2025 Apple Inc.
import ipaddress
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
rdma: list[Optional[str]]
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")

View File

@@ -1,540 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import json
import os
import shlex
import shutil
import sys
import tempfile
import threading
from collections import Counter
from itertools import chain
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
import mlx.core as mx
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
class CommandProcess:
@property
def process(self):
"""Return the Popen object that refers to the current command."""
raise NotImplementedError()
@property
def exit_status(self):
"""Return a tuple (returncode, killed) for the command. It should be
(None, None) while the command is running normally."""
raise NotImplementedError()
def preprocess_output(self, data: str, is_stdout=False):
"""Preprocess the output of the command so that extra data can be
capture or the format changed on the fly."""
raise NotImplementedError()
def terminate(self):
"""Terminate or return the exit code."""
raise NotImplementedError()
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
self._host = host
self._pidfile = None
self._is_local = is_local
self._process = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
self._killed = False
@property
def process(self):
return self._process
@property
def exit_status(self):
return self._process.poll(), self._killed
def preprocess_output(self, data, is_stdout=False):
if self._pidfile is None:
pidfile, *rest = data.split("\n", maxsplit=1)
self._pidfile = pidfile
return rest[0] if rest else ""
return data
def terminate(self):
if self._killed:
return
self._process.terminate()
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
def _launch_with_io(command_class, arguments, verbose):
stop = False
exit_codes = [(None, None)] * len(arguments)
def _thread_fn(rank, *args, **kwargs):
stdin_queue = kwargs.pop("stdin_queue")
stdout_queue = kwargs.pop("stdout_queue")
stderr_queue = kwargs.pop("stderr_queue")
command = command_class(rank, *args, **kwargs)
p = command.process
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
stdin_buffer = b""
while p.poll() is None:
try:
stdin_buffer += stdin_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
msg = os.read(fd, 8192).decode(errors="ignore")
msg = command.preprocess_output(msg, is_stdout)
if is_stdout:
stdout_queue.put(msg.encode())
else:
stderr_queue.put(msg.encode())
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
command.terminate()
break
exit_codes[rank] = command.exit_status
if exit_codes[rank][1]:
log_warning(f"Node with rank {rank} was killed")
elif exit_codes[rank][0] != 0:
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
else:
log(verbose, f"Node with rank {rank} completed")
stdin_queues = []
stdout_queues = []
stderr_queues = []
threads = []
for i, (args, kwargs) in enumerate(arguments):
stdin_queues.append(Queue())
stdout_queues.append(Queue())
stderr_queues.append(Queue())
t = threading.Thread(
target=_thread_fn,
args=args,
kwargs=kwargs
| {
"stdin_queue": stdin_queues[-1],
"stdout_queue": stdout_queues[-1],
"stderr_queue": stderr_queues[-1],
},
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
os.set_blocking(sys.stdout.fileno(), True)
os.set_blocking(sys.stderr.fileno(), True)
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
# Broadcast user input to the jobs
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in stdin_queues:
q.put(stdin_buffer)
# Gather job output
for q in stdout_queues:
try:
while not q.empty():
sys.stdout.buffer.write(q.get_nowait())
except QueueEmpty:
pass
for q in stderr_queues:
try:
while not q.empty():
sys.stderr.buffer.write(q.get_nowait())
except QueueEmpty:
pass
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
# Check if all are running and terminate otherwise
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i][0] != 0:
stop = True
break
else:
break
# Wait for the jobs to finish
for t in threads:
t.join()
# Process any remaining outputs
for q in stdout_queues:
while not q.empty():
sys.stdout.buffer.write(q.get())
for q in stderr_queues:
while not q.empty():
sys.stderr.buffer.write(q.get())
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
def launch_ring(parser, hosts, args, command):
if any(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
files = {"MLX_HOSTFILE": hostfile}
env = args.env
if args.verbose:
env.append("MLX_RING_VERBOSE=1")
cwd = args.cwd
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_nccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
master_host = hosts[0].ips[0]
master_port = args.nccl_port
world_size = len(hosts)
env = args.env
cwd = args.cwd
if args.verbose:
env.append("NCCL_DEBUG=INFO")
env.append(f"NCCL_HOST_IP={master_host}")
env.append(f"NCCL_PORT={master_port}")
env.append(f"MLX_WORLD_SIZE={world_size}")
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
(
(
rank,
h.ssh_hostname,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
command,
),
{},
)
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_jaccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
if not have_rdmas or not have_nulls:
raise ValueError("Malformed hostfile for jaccl backend")
coordinator = hosts[0].ips[0]
env = args.env
cwd = args.cwd
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=32323,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)

View File

@@ -832,7 +832,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
choices=["ring", "mpi", "nccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
@@ -903,8 +903,6 @@ def main():
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--plat manylinux_2_35_${1} \
--exclude libcublas* \
--exclude libnvrtc* \
--exclude libcuda* \

View File

@@ -52,25 +52,9 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"is_available",
[](const std::string& backend) {
return mx::distributed::is_available(backend);
},
"backend"_a = "any",
nb::sig("def is_available(backend: str = 'any') -> bool"),
&mx::distributed::is_available,
R"pbdoc(
Check if a communication backend is available.
Note, this function returns whether MLX has the capability of
instantiating that distributed backend not whether it is possible to
create a communication group. For that purpose one should use
``init(strict=True)``.
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as ``init()``. Default: ``any``.
Returns:
bool: Whether the distributed backend is available.
)pbdoc");
m.def(
@@ -95,10 +79,10 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
set to ``any`` all available backends are tried and the first one
that succeeds becomes the global group which will be returned in
subsequent calls. Default: ``any``
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``
Returns:
Group: The group representing all the launched processes.

View File

@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -257,8 +257,8 @@ if __name__ == "__main__":
}
entry_points = {
"console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
install_requires = []