Compare commits

..

13 Commits

Author SHA1 Message Date
Angelos Katharopoulos
dd91ee9534 Refactoring launcher 2025-12-08 02:57:50 -08:00
Angelos Katharopoulos
8fab4f0929 Change the name to a fun pun 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
47af2c8cb0 Add headers for gcc 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
f40152ebc1 Expose per-backend availability in C++ and python 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
5d7e6a0642 Add a no_ibv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b9b78b1059 Add empty sum_scatter 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
45727b0c02 Add send/recv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
2444fbdfe9 Make sure that there is space for work completions 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
f3b605e53c Add working reduce and semi-working all gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
0388ae3aaf Fix ring 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
d4c1de4a8b Fix side channel initialization for more than 2 peers 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
4dbffb3954 All gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b1a60b2d2d Initial working all reduce 2025-12-04 14:20:51 -08:00
28 changed files with 1063 additions and 1070 deletions

View File

@@ -1,15 +1,6 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
@@ -21,4 +12,4 @@ runs:
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
bash python/scripts/repair_cuda.sh

View File

@@ -15,7 +15,6 @@ runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,11 +128,7 @@ jobs:
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -140,11 +136,9 @@ jobs:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -29,20 +29,17 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda12]
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.5
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^

View File

@@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

24
mlx/allocator.cpp Normal file
View File

@@ -0,0 +1,24 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,25 +49,4 @@ class Allocator {
Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -82,28 +82,6 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,16 +57,6 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,

View File

@@ -20,19 +20,6 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
@@ -48,7 +35,13 @@ SmallSizePool::SmallSizePool() {
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
auto loc = cuda_mem_loc(i);
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
@@ -97,10 +90,9 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.9;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -112,10 +104,6 @@ CudaAllocator::CudaAllocator()
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
@@ -166,35 +154,23 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
cudaError_t err;
void* data = nullptr;
if (device == -1) {
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
err = cudaMallocManaged(&data, size);
} else {
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!data) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
return Buffer{nullptr};
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
}
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);

View File

@@ -71,14 +71,11 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_;
};

View File

@@ -95,14 +95,11 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 8;
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4 && dim0 < 8) {
if (dim0 >= 4) {
work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
@@ -113,10 +110,7 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
@@ -133,9 +127,7 @@ void copy_general_input(
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(

View File

@@ -318,52 +318,46 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"});
}
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// Constructs a key representing the nodes of a sub-graph.
// Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return {key + ")", true};
return true;
}
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
return is_graph_updatable(child, cluster_dim_x);
} else if (type != cudaGraphNodeTypeKernel) {
is_updatable = false;
return false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
}
key += ")";
return {key, is_updatable};
return true;
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -376,10 +370,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_updatable;
int cluster_dim_x = 0;
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, sub_graph_key});
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
}
bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// () = subgraph (with metadata)
// G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;

View File

@@ -89,13 +89,9 @@ template <
int NDIM,
int BM,
int BN,
int N_READS = 4,
int BLOCKS = 1>
__global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -106,8 +102,6 @@ __global__ void col_reduce_looped(
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
@@ -124,23 +118,12 @@ __global__ void col_reduce_looped(
totals[i] = ReduceInit<Op, T>::value();
}
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -149,7 +132,7 @@ __global__ void col_reduce_looped(
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -159,7 +142,7 @@ __global__ void col_reduce_looped(
}
}
} else {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
@@ -190,9 +173,6 @@ __global__ void col_reduce_looped(
// Write result.
if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN,
@@ -247,12 +227,11 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args,
int bn,
int outer = 1) {
int bn) {
int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
size_t n_blocks = n_outer_blocks * n_inner_blocks;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
}
@@ -298,8 +277,7 @@ void col_reduce_looped(
0,
indata,
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
static_cast<cu::ColReduceArgs>(args));
});
});
});
@@ -342,117 +320,6 @@ void col_reduce_small(
});
}
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -467,18 +334,6 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
@@ -494,14 +349,6 @@ void col_reduce(
return;
}
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -80,6 +80,7 @@ CudaGraph::CudaGraph(cu::Device& device) {
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}

View File

@@ -7,6 +7,8 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}

View File

@@ -149,9 +149,7 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
return Buffer{nullptr};
}
lk.lock();
num_resources_++;
@@ -203,32 +201,6 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,9 +21,6 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};

View File

@@ -25,7 +25,6 @@ class CommonAllocator : public Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};

View File

@@ -1,6 +1,5 @@
# Copyright © 2025 Apple Inc.
import argparse
import ipaddress
import json
import sys
@@ -17,14 +16,6 @@ class Host:
rdma: list[Optional[str]]
class OptionalBoolAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if option_string.startswith("--no-"):
setattr(namespace, self.dest, False)
else:
setattr(namespace, self.dest, True)
def positive_number(x):
x = int(x)
if x <= 0:
@@ -35,7 +26,6 @@ def positive_number(x):
def log(verbose, *args, **kwargs):
if not verbose:
return
kwargs["file"] = sys.stderr
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
@@ -60,7 +50,7 @@ def parse_hostlist(parser, hostlist, repeats):
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips, []))
hosts.append(Host(i, h, ips))
return hosts

View File

@@ -1,570 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional
import mlx.core as mx
from .common import (
Host,
OptionalBoolAction,
log,
log_error,
parse_hostfile,
parse_hostlist,
)
@dataclass
class SSHInfo:
can_ssh: bool
has_sudo: bool
def __bool__(self):
return self.can_ssh
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def add_ethernet_ips(hosts, verbose=False):
# Get the ips for each host
for h in hosts:
log(verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
def check_rdma(hosts, verbose=False):
# Check whether the hosts are capable of RDMA over thunderbolt
warn = False
for h in hosts:
log(verbose, "Checking that", h.ssh_hostname, "supports RDMA")
rdma_devs = (
run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True)
.stdout.strip()
.split()
)
rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")]
if not rdma_devs:
log_warning(h.ssh_hostname, "does not seem to have RDMA enabled")
warn = True
if warn:
log_warning()
log_warning(
"Some of the hosts don't have RDMA enabled or they don't support RDMA."
)
log_warning()
log_warning(
"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
)
log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
has_sudo = all(info.has_sudo for info in sshinfo)
if not has_sudo and auto_setup:
log_warning(
"Automatic setup requested but the following hosts do not have passwordless sudo"
)
for h, i in zip(hosts, sshinfo):
if not i.has_sudo:
log_warning(" - ", h.ssh_hostname)
return has_sudo
class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set()
ips = defaultdict(list)
ip0 = 0
ip1 = 0
for src_node, h in enumerate(tb_hosts):
for src_port, p in enumerate(h.ports):
if not p.connected_to:
continue
if p.connected_to not in uuid_reverse_index:
continue
if (src_node, src_port) in assigned:
continue
dst_node, dst_port = uuid_reverse_index[p.connected_to]
ip_src = f"192.168.{ip0}.{ip1 + 1}"
ip_dst = f"192.168.{ip0}.{ip1 + 2}"
iface_src = p.iface
iface_dst = tb_hosts[dst_node].ports[dst_port].iface
ips[src_node, dst_node].append((iface_src, ip_src))
ips[dst_node, src_node].append((iface_dst, ip_dst))
assigned.add((src_node, src_port))
assigned.add((dst_node, dst_port))
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs")
self.ips = ips
self.hosts = hosts
self.tb_hosts = tb_hosts
def setup(self, verbose=False, auto_setup=False):
netmask = "255.255.255.252"
for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j in range(len(self.hosts)):
if i == j or (i, j) not in self.ips:
continue
for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " ; ")
command = ["ssh", h.ssh_hostname, command]
log(verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def extract_connectivity(hosts, verbose):
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
return tb_hosts, uuid_reverse_index
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
connectivity = []
for i, h in enumerate(tb_hosts):
c = [0] * len(tb_hosts)
for p in h.ports:
if p.connected_to in uuid_reverse_index:
j, _ = uuid_reverse_index[p.connected_to]
c[j] += 1
connectivity.append(c)
return connectivity
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
# Make ids per node
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for i, h in enumerate(tb_hosts):
for p in h.ports:
if not p.connected_to:
continue
dst = uuid_reverse_index[p.connected_to]
if dst[0] < i:
continue
print(f" {names[i]} -- {names[dst[0]]}", end="")
print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
print("}")
def extract_rings(connectivity):
rings = []
existing_rings = set()
num_nodes = len(connectivity)
def dfs(start_node, node, path, visited):
path.append(node)
visited.add(node)
for j in range(num_nodes):
if connectivity[node][j] <= 0:
continue
if j == start_node:
yield path[:]
if j not in visited:
yield from dfs(start_node, j, path, visited)
path.pop()
visited.remove(node)
for start in range(num_nodes):
for r in dfs(start, start, [], set()):
cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))
rkey = tuple(sorted(r))
if rkey not in existing_rings:
rings.append((r, cnt))
existing_rings.add(rkey)
return sorted(rings, key=lambda x: -len(x[0]))
def check_valid_mesh(hosts, connectivity, strict=True):
num_nodes = len(connectivity)
for i in range(num_nodes):
for j in range(num_nodes):
if i == j:
continue
if connectivity[i][j] <= 0:
if strict:
log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
)
log_error()
log_error("Try passing --dot to visualize the connectivity")
sys.exit(1)
else:
return False
return True
def check_ssh_connections(hosts):
results = [None] * len(hosts)
def _check(hostname, i):
info = SSHInfo(False, False)
results[i] = info
# Check for ssh
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.can_ssh = result.returncode == 0
if not info.can_ssh:
return
# Check for sudo
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"sudo",
"ls",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.has_sudo = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
return results
def prepare_ethernet_hostfile(args, hosts):
log(args.verbose, f"Preparing an ethernet hostfile")
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring, sshinfo):
log(args.verbose, "Prepare a ring hostfile")
ring, count = ring
hostfile = []
for i, node in enumerate(ring):
h = hosts[node]
peer = ring[i - 1]
hostfile.append(
{
"ssh": h.ssh_hostname,
"ips": [ips.ips[node, peer][c][1] for c in range(count)],
"rdma": [],
}
)
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips, sshinfo):
log(args.verbose, "Prepare a jaccl hostfile")
check_rdma(hosts, args.verbose)
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for i, h in enumerate(hosts):
rdma = []
for j in range(len(hosts)):
if i == j:
rdma.append(None)
else:
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts, sshinfo):
log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
if args.dot:
tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
return
ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)
if args.backend is None:
rings = extract_rings(connectivity)
has_mesh = check_valid_mesh(hosts, connectivity, False)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring and not has_mesh:
log_error("Neither thunderbolt mesh nor ring found.")
log_error("Perhaps run with --dot to generate a plot of the connectivity.")
sys.exit(1)
elif has_ring:
configure_ring(args, hosts, ips, rings[0], sshinfo)
else:
configure_jaccl(args, hosts, ips, sshinfo)
elif args.backend == "ring":
rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring:
log_error("Could not find a full ring.")
log_error()
log_error("Try passing --dot to visualize the connectivity")
if len(rings) > 0:
log_error("Rings found:")
for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1)
configure_ring(args, hosts, ips, rings[0], sshinfo)
elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips, sshinfo)
def main():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
required=True,
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
"--no-auto-setup",
action=OptionalBoolAction,
nargs=0,
dest="auto_setup",
default=None,
)
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--backend",
choices=["ring", "jaccl"],
default=None,
help="Which distributed backend to configure",
)
args = parser.parse_args()
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
# Check that we can ssh
log(
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
sshinfo = check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
if args.over == "ethernet":
prepare_ethernet_hostfile(args, hosts)
# Configure the macs for communication over thunderbolt, both via RDMA and IP
else:
prepare_tb_hostfile(args, hosts, sshinfo)

View File

@@ -0,0 +1,911 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
import tempfile
import threading
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
import mlx.core as mx
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def get_num_nvidia_gpus():
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n"))
def extract_rings(hosts, index):
def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
def dfs(start_node, node, path, visited, used_ports):
path.append(node)
visited.add(node)
for j, p in enumerate(hosts[node].ports):
if not usable_port(node, j, used_ports):
continue
next_node, _ = index[p.connected_to]
if next_node == start_node:
yield path[:]
if next_node not in visited:
yield from dfs(start_node, next_node, path, visited, used_ports)
path.pop()
visited.remove(node)
# Concretize maps the found cycle to real thunderbolt ports. It also adds
# those ports to the used set so next cycles can't use them again.
def concretize(cycle, used_ports):
concrete_path = []
for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
for j, p in enumerate(hosts[n1].ports):
if not usable_port(n1, j, used_ports):
continue
n2_hat, nj = index[p.connected_to]
if n2 == n2_hat:
concrete_path.append(((n1, j), (n2, nj)))
used_ports.add((n1, j))
used_ports.add((n2, nj))
break
if concrete_path[-1][0][0] != n1:
raise RuntimeError("Couldn't concretize the cycle")
return concrete_path
# Normalize tries to ensure that the cycles have the same direction so we can
# use them together. We achieve this by selecting the direction such that
# the smallest rank hosts connect to larger rank hosts.
def normalize(path):
small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
if small_to_large > len(path) - small_to_large:
return path
else:
return [(p[1], p[0]) for p in path]
rings = []
used_ports = set()
for start_node in range(len(hosts)):
while True:
ring = []
for r in dfs(start_node, start_node, [], set(), used_ports):
if len(r) > len(ring):
ring = r
# Break early since we won't find a bigger ring no matter what
if len(ring) == len(hosts):
break
if not ring:
break
try:
rings.append(normalize(concretize(ring, used_ports)))
except RuntimeError:
if len(rings) > 0:
return rings
raise
return rings
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += (
f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
)
script += f" sys.exit(1)\n"
# Add the environment variables that were given to us
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Add the environment variables to enable the ring distributed backend
if hostfile != "":
script += "_, hostfile = tempfile.mkstemp()\n"
script += "with open(hostfile, 'w') as f:\n"
script += f" f.write({repr(hostfile)})\n"
if verbose:
script += "env['MLX_RING_VERBOSE'] = '1'\n"
script += "env['MLX_HOSTFILE'] = hostfile\n"
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
pidfile = ""
stdin_buffer = b""
stdout_buffer = b""
stderr_buffer = b""
while p.poll() is None:
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
msg = os.read(fd, 8192).decode(errors="ignore")
# Fetch the PID file first if we haven't already
if pidfile == "":
pidfile, *msg = msg.split("\n", maxsplit=1)
msg = msg[0] if msg else ""
is_stdout = fd == p.stdout.fileno()
if is_stdout:
stdout_buffer += msg.encode()
else:
stderr_buffer += msg.encode()
for fd in wlist:
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
n = os.write(fd, stdout_buffer)
stdout_buffer = stdout_buffer[n:]
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
n = os.write(fd, stderr_buffer)
stderr_buffer = stderr_buffer[n:]
if stop:
p.terminate()
break
p.wait()
exit_codes[rank] = p.returncode
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {pidfile}"
if not is_local:
cmd = f"ssh {host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
if c.stdout.strip() == "1":
log_warning(f"Node with rank {rank} was killed")
elif p.returncode != 0:
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
else:
log(args.verbose, f"Node with rank {rank} completed")
if all(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i] != 0:
stop = True
break
else:
break
for t in threads:
t.join()
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
num_gpus = get_num_nvidia_gpus()
if num_gpus == 0:
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
if args.repeat_hosts > num_gpus:
raise RuntimeError("NCCL requires a separate GPU per process.")
try:
for rank in range(world_size):
env = base_env.copy()
mlx_rank = str(rank % args.repeat_hosts)
env["MLX_RANK"] = mlx_rank
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts):
results = [False] * len(hosts)
def _check(hostname, i):
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=PIPE,
stderr=PIPE,
)
results[i] = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
def prepare_tb_ring(args, hosts):
log(
args.verbose,
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
if args.auto_setup and args.verbose:
log_warning(
"--auto-setup is requested which requires password-less sudo",
"on the remote hosts",
)
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(args.verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
# Find the rings by simply walking and marking visited (host, port) tuples
# and keeping the largest rings greedily.
log(args.verbose, "Extracting rings from the parsed connectivity")
rings = extract_rings(tb_hosts, uuid_reverse_index)
# Just output a DOT graphical representation of the found rings
if args.dot:
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for r in rings:
for (i, _), (j, _) in r:
print(f" {names[i]} -- {names[j]};")
print("}")
return
# Assign IPs to each interface such that the interfaces can communicate
ips = {}
pairs = {}
expecting = set()
ip0 = 0
ip1 = 0
netmask = "255.255.255.252"
for r in rings:
for a, b in r:
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
pairs[a] = b
pairs[b] = a
expecting.add(b)
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs for the ring")
# Extract the host order from the first ring
hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
first_host = min(hostmap.keys())
order = [first_host]
while hostmap[order[-1]] != first_host:
order.append(hostmap[order[-1]])
# Create the hostfile
hostfile = []
for i in order:
h = hosts[i]
host = {
"ssh": h.ssh_hostname,
"ips": [
ips[i, j]
for j, p in enumerate(tb_hosts[i].ports)
if (i, j) in expecting
],
}
hostfile.append(host)
if not args.hostfile_only:
for i, h in enumerate(hosts):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j, p in enumerate(tb_hosts[i].ports):
if (i, j) not in ips:
continue
iface = p.iface
ip = ips[i, j]
peer = ips[pairs[i, j]]
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if args.auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " && ")
command = ["ssh", h.ssh_hostname, command]
log(args.verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_hostfile(args, hosts):
log(
args.verbose,
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
# Get the ips for each host
for h in hosts:
log(args.verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def distributed_config():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to configure",
)
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
action="store_true",
help="If set we will attempt to automatically configure the machines via ssh",
)
args = parser.parse_args()
if args.backend == "mpi" and args.over == "thunderbolt":
raise ValueError(
(
"The configuration of MPI over thunderbolt is "
"not supported yet by mlx.distributed_config"
)
)
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
if args.over == "thunderbolt":
prepare_tb_ring(args, hosts)
else:
prepare_hostfile(args, hosts)
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=5000,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)
if __name__ == "__main__":
main()

View File

@@ -45,11 +45,13 @@ class CommandProcess:
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, python, cwd, files, env, command):
def __init__(self, rank, host, cwd, files, env, command):
is_local = host == "127.0.0.1"
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh -tt -o LogLevel=QUIET {host} {shlex.quote(cmd)}"
cmd = f"ssh {host} '{cmd}'"
self._host = host
self._pidfile = None
@@ -57,7 +59,6 @@ class RemoteProcess(CommandProcess):
self._process = Popen(
cmd,
shell=True,
executable="/bin/bash",
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
@@ -89,43 +90,47 @@ class RemoteProcess(CommandProcess):
self._process.wait()
# Kill the remote program if possible
cmd = RemoteProcess.make_kill_script(self._pidfile)
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
if not self._is_local:
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
c = run(
cmd,
check=True,
shell=True,
executable="/bin/bash",
capture_output=True,
text=True,
)
cmd = f"ssh {self._host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_launch_script(rank, cwd, files, env, command):
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
script = ""
# Disable echo
script = "stty -echo; "
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "pidfile=$(mktemp); "
script += "echo $$ > $pidfile; "
script += 'printf "%s\\n" $pidfile; '
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if [[ -d {repr(d)} ]]; then "
script += f" cd {repr(d)}; "
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else "
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
script += "fi; "
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
@@ -134,34 +139,22 @@ class RemoteProcess(CommandProcess):
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"export {key}={value}; "
script += f"env[{repr(key)}] = {repr(value)}\n"
# Make the temporary files
for env_name, content in files.items():
script += "fname=$(mktemp); "
script += f"echo {shlex.quote(content)} >$fname; "
script += f"export {env_name}=$fname; "
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
# Finally add the rank
script += f"export MLX_RANK={rank}; "
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
script += 'exec "${cmd[@]}"'
return script
@staticmethod
def make_kill_script(pidfile):
script = ""
script += f"pid=$(cat {pidfile}); "
script += "if ps -p $pid >/dev/null; then "
script += " kill $pid; "
script += " echo 1; "
script += "else "
script += " echo 0; "
script += "fi; "
script += f"rm {pidfile}"
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
@@ -316,7 +309,7 @@ def launch_ring(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -348,7 +341,6 @@ def launch_nccl(parser, hosts, args, command):
(
rank,
h.ssh_hostname,
args.python,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
@@ -382,7 +374,7 @@ def launch_jaccl(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -511,20 +503,11 @@ def main():
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
parser.add_argument(
"--no-verify-script",
action="store_false",
dest="verify_script",
help="Do not verify that the script exists",
)
parser.add_argument(
"--python", default=sys.executable, help="Use this python on the remote hosts"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(args.python)
print(sys.executable)
return
if len(rest) == 0:
@@ -540,10 +523,10 @@ def main():
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [args.python, str(script.resolve())]
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
elif args.verify_script:
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_${1} \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc* \
--exclude libcuda* \

View File

@@ -210,14 +210,6 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,21 +7,13 @@ import re
import subprocess
from functools import partial
from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version():
with open("mlx/version.h", "r") as fid:
for l in fid:
@@ -39,7 +31,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release:
git_hash = (
subprocess.run(
run(
"git rev-parse --short HEAD".split(),
capture_output=True,
check=True,
@@ -266,7 +258,7 @@ if __name__ == "__main__":
entry_points = {
"console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main",
"mlx.distributed_config = mlx._distributed_utils.config:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
install_requires = []
@@ -292,11 +284,7 @@ if __name__ == "__main__":
install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"'
)
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup(
@@ -311,25 +299,13 @@ if __name__ == "__main__":
if build_macos:
name = "mlx-metal"
elif build_cuda:
toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
name = "mlx-cuda"
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
]
else:
name = "mlx-cpu"
_setup(

View File

@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <climits>
#include "doctest/doctest.h"
@@ -607,24 +608,3 @@ TEST_CASE("test make empty array") {
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_);
}
TEST_CASE("test make array from user buffer") {
int size = 4096;
std::vector<int> buffer(size, 0);
int count = 0;
auto deleter = [&count](void*) { count++; };
{
auto a = array(buffer.data(), Shape{size}, int32, deleter);
if (metal::is_available()) {
CHECK_EQ(buffer.data(), a.data<int>());
}
auto b = a + array(1);
eval(b);
auto expected = ones({4096});
CHECK(array_equal(b, expected).item<bool>());
}
// deleter should always get called
CHECK_EQ(count, 1);
}