mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
ibv-backen
...
b862d842e1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b862d842e1 | ||
|
|
f7a400951a | ||
|
|
27232db1ba | ||
|
|
a4b3bc969b | ||
|
|
667c0f3bb9 | ||
|
|
6245824d42 | ||
|
|
39289ef025 | ||
|
|
aefc9bd3f6 | ||
|
|
997cfc7699 | ||
|
|
1fa8dc5797 |
11
.github/actions/build-cuda-release/action.yml
vendored
11
.github/actions/build-cuda-release/action.yml
vendored
@@ -1,6 +1,15 @@
|
||||
name: 'Build CUDA wheel'
|
||||
description: 'Build CUDA wheel'
|
||||
|
||||
inputs:
|
||||
arch:
|
||||
description: 'Platform architecture tag'
|
||||
required: true
|
||||
type: choice
|
||||
options:
|
||||
- x86_64
|
||||
- aarch64
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
@@ -12,4 +21,4 @@ runs:
|
||||
pip install auditwheel build patchelf setuptools
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
|
||||
|
||||
1
.github/actions/setup-linux/action.yml
vendored
1
.github/actions/setup-linux/action.yml
vendored
@@ -15,6 +15,7 @@ runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Use ccache
|
||||
if: ${{ runner.arch == 'x86_64' }}
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||
|
||||
10
.github/workflows/release.yml
vendored
10
.github/workflows/release.yml
vendored
@@ -128,7 +128,11 @@ jobs:
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['x86_64', 'aarch64']
|
||||
toolkit: ['cuda-12.9', 'cuda-13.0']
|
||||
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
@@ -136,9 +140,11 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
|
||||
@@ -119,10 +119,6 @@ if(MLX_BUILD_METAL)
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
|
||||
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
|
||||
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cuda]
|
||||
pip install mlx[cuda12]
|
||||
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia architecture >= SM 7.5
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
@@ -28,16 +28,16 @@ class Buffer {
|
||||
};
|
||||
};
|
||||
|
||||
Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
virtual Buffer make_buffer(void* ptr, size_t size) {
|
||||
return Buffer{nullptr};
|
||||
};
|
||||
virtual void release(Buffer buffer) {}
|
||||
|
||||
Allocator() = default;
|
||||
Allocator(const Allocator& other) = delete;
|
||||
@@ -49,4 +49,25 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
inline Buffer malloc(size_t size) {
|
||||
return allocator().malloc(size);
|
||||
}
|
||||
|
||||
inline void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
// Make a Buffer from a raw pointer of the given size without a copy. If a
|
||||
// no-copy conversion is not possible then the returned buffer.ptr() will be
|
||||
// nullptr. Any buffer created with this function must be released with
|
||||
// release(buffer)
|
||||
inline Buffer make_buffer(void* ptr, size_t size) {
|
||||
return allocator().make_buffer(ptr, size);
|
||||
};
|
||||
|
||||
// Release a buffer from the allocator made with make_buffer
|
||||
inline void release(Buffer buffer) {
|
||||
allocator().release(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
||||
@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(
|
||||
void* data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
const std::function<void(void*)>& deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
auto buffer = allocator::make_buffer(data, nbytes());
|
||||
if (buffer.ptr() == nullptr) {
|
||||
set_data(allocator::malloc(nbytes()));
|
||||
auto ptr = static_cast<char*>(data);
|
||||
std::copy(ptr, ptr + nbytes(), this->data<char>());
|
||||
deleter(data);
|
||||
} else {
|
||||
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
|
||||
auto ptr = buffer.ptr();
|
||||
allocator::release(buffer);
|
||||
return deleter(ptr);
|
||||
};
|
||||
set_data(buffer, std::move(wrapped_deleter));
|
||||
}
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
|
||||
10
mlx/array.h
10
mlx/array.h
@@ -57,6 +57,16 @@ class array {
|
||||
Shape shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a raw pointer. The constructor will attempt to use the
|
||||
* input data without a copy. The deleter will be called when the array no
|
||||
* longer needs the underlying memory - after the array is destroyed in the
|
||||
* no-copy case and after the copy otherwise. */
|
||||
explicit array(
|
||||
void* data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
const std::function<void(void*)>& deleter);
|
||||
|
||||
/* Build an array from a buffer */
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
|
||||
@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
|
||||
// Any allocations smaller than this will try to use the small pool
|
||||
constexpr int small_block_size = 8;
|
||||
|
||||
#if CUDART_VERSION >= 13000
|
||||
inline cudaMemLocation cuda_mem_loc(int i) {
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
return loc;
|
||||
}
|
||||
#else
|
||||
inline int cuda_mem_loc(int i) {
|
||||
return i;
|
||||
}
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
|
||||
// The small pool size in bytes. This should be a multiple of the host page
|
||||
// size and small_block_size.
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
#else
|
||||
int loc = i;
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
auto loc = cuda_mem_loc(i);
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||
}
|
||||
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.9;
|
||||
size_t free;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
|
||||
memory_limit_ = total_memory_ * 0.95;
|
||||
free_limit_ = total_memory_ - memory_limit_;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
||||
free_streams_.push_back(s);
|
||||
|
||||
cudaMemPool_t mem_pool;
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
|
||||
mem_pools_.push_back(mem_pool);
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
||||
}
|
||||
@@ -154,23 +166,35 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
cudaError_t err;
|
||||
void* data = nullptr;
|
||||
if (device == -1) {
|
||||
err = cudaMallocManaged(&data, size);
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
|
||||
} else {
|
||||
err = cudaMallocAsync(&data, size, stream);
|
||||
}
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
|
||||
}
|
||||
if (!data) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf = new CudaBuffer{data, size, device};
|
||||
}
|
||||
lock.lock();
|
||||
|
||||
// If any cuda memory pool has too much reserved memory, clear some
|
||||
// memory from the cache. This prevents graph / kernel execution failing
|
||||
// from OOM
|
||||
if (get_cache_memory() > 0) {
|
||||
for (auto p : mem_pools_) {
|
||||
size_t used = 0;
|
||||
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
|
||||
p, cudaMemPoolAttrReservedMemCurrent, &used));
|
||||
if (used > (total_memory_ - free_limit_)) {
|
||||
buffer_cache_.release_cached_buffers(free_limit_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
active_memory_ += buf->size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
@@ -71,11 +71,14 @@ class CudaAllocator : public allocator::Allocator {
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t free_limit_;
|
||||
size_t total_memory_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
std::vector<cudaStream_t> free_streams_;
|
||||
std::vector<cudaMemPool_t> mem_pools_;
|
||||
SmallSizePool scalar_pool_;
|
||||
};
|
||||
|
||||
|
||||
@@ -95,11 +95,14 @@ void copy_general_input(
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
|
||||
int work_per_thread = 8;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
if (dim0 >= 4 && dim0 < 8) {
|
||||
work_per_thread = 4;
|
||||
} else if (dim0 < 4) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
|
||||
@@ -318,46 +318,64 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
insert_graph_dependencies(GraphNode{node, "K"});
|
||||
}
|
||||
|
||||
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
||||
// has a different cluster shape than the node it's being updated with.
|
||||
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||
// Constructs a key representing the nodes of a sub-graph.
|
||||
// Also checks if the sub-graph is updatable as CUDA graphs do not get
|
||||
// updated correctly if a kernel node getting updated has a different cluster
|
||||
// shape than the node it's being updated with.
|
||||
std::string key = "(";
|
||||
size_t num_nodes = 0;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||
if (num_nodes == 0) {
|
||||
return true;
|
||||
return {key + ")", true};
|
||||
}
|
||||
|
||||
bool is_updatable = true;
|
||||
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||
for (const auto& node : nodes) {
|
||||
if (!is_updatable) {
|
||||
break;
|
||||
}
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
switch (type) {
|
||||
case cudaGraphNodeTypeGraph: {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
if (num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
return is_graph_updatable(child, cluster_dim_x);
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
return false;
|
||||
} else {
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
break;
|
||||
}
|
||||
case cudaGraphNodeTypeMemset:
|
||||
key += "M";
|
||||
break;
|
||||
case cudaGraphNodeTypeKernel: {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only dim.x can be greater than 1
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
return false;
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
}
|
||||
// Only one child node allowed when subgraph uses clusters
|
||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
case cudaGraphNodeTypeWaitEvent:
|
||||
key += "W";
|
||||
break;
|
||||
case cudaGraphNodeTypeEventRecord:
|
||||
key += "R";
|
||||
break;
|
||||
default:
|
||||
is_updatable = false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
key += ")";
|
||||
return {key, is_updatable};
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
@@ -370,11 +388,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
return;
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
int cluster_dim_x = 0;
|
||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
||||
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
|
||||
is_graph_updatable_ &= is_updatable;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(
|
||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||
insert_graph_dependencies(GraphNode{node, sub_graph_key});
|
||||
}
|
||||
|
||||
bool CommandEncoder::needs_commit() {
|
||||
|
||||
@@ -106,7 +106,7 @@ class CommandEncoder {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
// E = empty
|
||||
// G* = subgraph (with metadata)
|
||||
// () = subgraph (with metadata)
|
||||
// Symbols ':', '-' are reserved as separators
|
||||
std::string node_type;
|
||||
std::string id;
|
||||
|
||||
@@ -89,9 +89,13 @@ template <
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void
|
||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
int N_READS = 4,
|
||||
int BLOCKS = 1>
|
||||
__global__ void col_reduce_looped(
|
||||
T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args,
|
||||
int64_t out_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
size_t tile_idx = grid.block_rank();
|
||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
||||
size_t tile_out = tile_y / out_size;
|
||||
tile_y = tile_y % out_size;
|
||||
|
||||
// Compute the indices for the thread within the tile
|
||||
short thread_x = block.thread_rank() % threads_per_row;
|
||||
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
size_t per_block, start, end;
|
||||
if constexpr (BLOCKS > 1) {
|
||||
per_block = (total + BLOCKS - 1) / BLOCKS;
|
||||
start = tile_out * per_block + thread_y;
|
||||
end = min((tile_out + 1) * per_block, total);
|
||||
} else {
|
||||
per_block = total;
|
||||
start = thread_y;
|
||||
end = total;
|
||||
}
|
||||
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||
if (args.reduction_stride % N_READS == 0) {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
thread_x,
|
||||
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
if (BLOCKS > 1) {
|
||||
out += tile_out * out_size * args.reduction_stride;
|
||||
}
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
||||
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args,
|
||||
int bn) {
|
||||
int bn,
|
||||
int outer = 1) {
|
||||
int gx, gy = 1;
|
||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
|
||||
while (n_blocks / gy > INT32_MAX) {
|
||||
gy *= 2;
|
||||
}
|
||||
@@ -277,7 +298,8 @@ void col_reduce_looped(
|
||||
0,
|
||||
indata,
|
||||
gpu_ptr<U>(out),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size() / args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -320,6 +342,117 @@ void col_reduce_small(
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce_two_pass(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
// Allocate an intermediate array to hold the 1st pass result
|
||||
constexpr int outer = 32;
|
||||
|
||||
Shape intermediate_shape;
|
||||
intermediate_shape.push_back(outer);
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
|
||||
Strides intermediate_strides;
|
||||
intermediate_strides.push_back(out.size());
|
||||
intermediate_strides.insert(
|
||||
intermediate_strides.end(), out.strides().begin(), out.strides().end());
|
||||
|
||||
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
|
||||
auto [data_size, rc, cc] =
|
||||
check_contiguity(intermediate_shape, intermediate_strides);
|
||||
auto fl = out.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = true;
|
||||
intermediate.set_data(
|
||||
cu::malloc_async(intermediate.nbytes(), encoder),
|
||||
data_size,
|
||||
intermediate_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(intermediate);
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(gpu_ptr<T>(in));
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel = cu::
|
||||
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
gpu_ptr<U>(intermediate),
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size() / args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prepare the reduction arguments for the 2nd pass
|
||||
cu::ColReduceArgs second_args = args;
|
||||
second_args.reduction_size = outer;
|
||||
second_args.reduction_stride = out.size();
|
||||
second_args.ndim = 0;
|
||||
second_args.reduce_shape[0] = outer;
|
||||
second_args.reduce_strides[0] = out.size();
|
||||
second_args.reduce_ndim = 1;
|
||||
second_args.non_col_reductions = 1;
|
||||
|
||||
encoder.set_input_array(intermediate);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
gpu_ptr<T>(intermediate),
|
||||
gpu_ptr<U>(out),
|
||||
second_args,
|
||||
second_args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@@ -334,6 +467,18 @@ void col_reduce(
|
||||
// It is a general strided reduce. Each threadblock computes the output for
|
||||
// a subrow of the fast moving axis. For instance 32 elements.
|
||||
//
|
||||
// - col_reduce_small
|
||||
//
|
||||
// It is a column reduce for small columns. Each thread loops over the whole
|
||||
// column without communicating with any other thread.
|
||||
//
|
||||
// - col_reduce_two_pass
|
||||
//
|
||||
// It is a reduce for long columns. To increase parallelism, we split the
|
||||
// reduction in two passes. First we do a column reduce where many
|
||||
// threadblocks operate on different parts of the reduced axis. Then we
|
||||
// perform a final column reduce.
|
||||
//
|
||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||
// leave transpositions as they are (contrary to our Metal backend).
|
||||
//
|
||||
@@ -349,6 +494,14 @@ void col_reduce(
|
||||
return;
|
||||
}
|
||||
|
||||
// Long column with smallish row
|
||||
size_t total_sums = args.non_col_reductions * args.reduction_size;
|
||||
size_t approx_threads = out.size();
|
||||
if (total_sums / approx_threads > 32) {
|
||||
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback col reduce
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||
}
|
||||
|
||||
@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
|
||||
}
|
||||
|
||||
void CudaGraph::end_capture(cudaStream_t stream) {
|
||||
assert(handle_ == nullptr);
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.lock();
|
||||
num_resources_++;
|
||||
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
|
||||
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
||||
}
|
||||
|
||||
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
|
||||
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.insert(buf);
|
||||
active_memory_ += buf->length();
|
||||
peak_memory_ = std::max(peak_memory_, active_memory_);
|
||||
num_resources_++;
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
void MetalAllocator::release(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
if (buf == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
active_memory_ -= buf->length();
|
||||
num_resources_--;
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
}
|
||||
|
||||
MetalAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
|
||||
@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
virtual Buffer make_buffer(void* ptr, size_t size) override;
|
||||
virtual void release(Buffer buffer) override;
|
||||
|
||||
size_t get_active_memory() {
|
||||
return active_memory_;
|
||||
};
|
||||
|
||||
@@ -25,6 +25,7 @@ class CommonAllocator : public Allocator {
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
size_t get_active_memory() const {
|
||||
return active_memory_;
|
||||
};
|
||||
|
||||
@@ -4,11 +4,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||
|
||||
if(MLX_BUILD_CPU AND NOT WIN32)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/jaccl/jaccl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
@@ -103,27 +102,7 @@ class EmptyGroup : public GroupImpl {
|
||||
} // namespace detail
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||
jaccl::is_available();
|
||||
}
|
||||
|
||||
bool is_available(const std::string& bk) {
|
||||
if (bk == "any") {
|
||||
return is_available();
|
||||
}
|
||||
if (bk == "mpi") {
|
||||
return mpi::is_available();
|
||||
}
|
||||
if (bk == "ring") {
|
||||
return ring::is_available();
|
||||
}
|
||||
if (bk == "nccl") {
|
||||
return nccl::is_available();
|
||||
}
|
||||
if (bk == "jaccl") {
|
||||
return jaccl::is_available();
|
||||
}
|
||||
return false;
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||
}
|
||||
|
||||
int Group::rank() const {
|
||||
@@ -156,8 +135,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "jaccl") {
|
||||
group = jaccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
if (mlx::core::cu::is_available()) {
|
||||
group = nccl::init(false);
|
||||
@@ -171,17 +148,13 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = mpi::init(false);
|
||||
bk_ = "mpi";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = jaccl::init(false);
|
||||
bk_ = "jaccl";
|
||||
}
|
||||
if (group == nullptr && strict) {
|
||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||
}
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
||||
<< "and 'ring' but '" << bk << "' was provided.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ class GroupImpl;
|
||||
|
||||
/* Check if a communication backend is available */
|
||||
bool is_available();
|
||||
bool is_available(const std::string& bk);
|
||||
|
||||
/**
|
||||
* A distributed::Group represents a group of independent mlx processes that
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
if(MLX_BUILD_CPU
|
||||
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
|
||||
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
|
||||
target_link_libraries(mlx PRIVATE rdma)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
|
||||
endif()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::jaccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available();
|
||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||
|
||||
} // namespace mlx::core::distributed::jaccl
|
||||
@@ -1,20 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/jaccl/jaccl.h"
|
||||
|
||||
namespace mlx::core::distributed::jaccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::jaccl
|
||||
@@ -1,38 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) const {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
void operator()(const T* input, T* output, size_t N) const {
|
||||
while (N-- > 0) {
|
||||
*output = std::max(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MinOp {
|
||||
void operator()(const T* input, T* output, size_t N) const {
|
||||
while (N-- > 0) {
|
||||
*output = std::min(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
@@ -1,6 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
@@ -19,8 +22,6 @@
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/reduction_ops.h"
|
||||
#include "mlx/distributed/utils.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#ifndef SOL_TCP
|
||||
@@ -93,7 +94,6 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
constexpr const char* RING_TAG = "[ring]";
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
using json = nlohmann::json;
|
||||
@@ -296,6 +296,55 @@ class CommunicationThreads {
|
||||
std::unordered_map<int, SocketThread> threads_;
|
||||
};
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
|
||||
const sockaddr* get() const {
|
||||
return (struct sockaddr*)&addr;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||
* addresses in order of rank. For each rank there can be many addresses so
|
||||
@@ -308,15 +357,15 @@ class CommunicationThreads {
|
||||
* ["ip3:5000", "ip3:5001"],
|
||||
* ]
|
||||
*/
|
||||
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<detail::address_t>> nodes;
|
||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<address_t>> nodes;
|
||||
std::ifstream f(hostfile);
|
||||
|
||||
json hosts = json::parse(f);
|
||||
for (auto& h : hosts) {
|
||||
std::vector<detail::address_t> host;
|
||||
std::vector<address_t> host;
|
||||
for (auto& ips : h) {
|
||||
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||
host.push_back(parse_address(ips.get<std::string>()));
|
||||
}
|
||||
nodes.push_back(std::move(host));
|
||||
}
|
||||
@@ -328,15 +377,73 @@ std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||
* Create a socket and accept one connection for each of the provided
|
||||
* addresses.
|
||||
*/
|
||||
std::vector<int> accept_connections(
|
||||
const std::vector<detail::address_t>& addresses) {
|
||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
detail::TCPSocket socket(RING_TAG);
|
||||
socket.listen(RING_TAG, address);
|
||||
sockets.push_back(socket.accept(RING_TAG).detach());
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock, address.get(), address.len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
|
||||
sockets.push_back(peer_socket);
|
||||
}
|
||||
|
||||
return sockets;
|
||||
@@ -347,42 +454,93 @@ std::vector<int> accept_connections(
|
||||
* provided addresses.
|
||||
*/
|
||||
std::vector<int> make_connections(
|
||||
const std::vector<detail::address_t>& addresses,
|
||||
const std::vector<address_t>& addresses,
|
||||
bool verbose) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
sockets.push_back(detail::TCPSocket::connect(
|
||||
RING_TAG,
|
||||
address,
|
||||
CONN_ATTEMPTS,
|
||||
CONN_WAIT,
|
||||
[verbose](int attempt, int wait) {
|
||||
int sock;
|
||||
|
||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
||||
// backoff. TODO: Do we need that?
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"waiting",
|
||||
"wait",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
})
|
||||
.detach());
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
success = connect(sock, address.get(), address.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockets.push_back(sock);
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::max(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MinOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::min(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class RingGroup : public GroupImpl {
|
||||
public:
|
||||
RingGroup(
|
||||
int rank,
|
||||
std::vector<std::vector<detail::address_t>> nodes,
|
||||
bool verbose)
|
||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||
throw std::runtime_error(
|
||||
@@ -475,17 +633,17 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
|
||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
|
||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
|
||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "mlx/distributed/utils.h"
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
std::memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(const char* tag) {
|
||||
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock_ < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||
sock_ = s.sock_;
|
||||
s.sock_ = -1;
|
||||
}
|
||||
|
||||
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
|
||||
if (this != &s) {
|
||||
sock_ = s.sock_;
|
||||
s.sock_ = -1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||
|
||||
TCPSocket::~TCPSocket() {
|
||||
if (sock_ > 0) {
|
||||
shutdown(sock_, 2);
|
||||
close(sock_);
|
||||
}
|
||||
}
|
||||
|
||||
int TCPSocket::detach() {
|
||||
int s = sock_;
|
||||
sock_ = -1;
|
||||
return s;
|
||||
}
|
||||
|
||||
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||
int success;
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock_, addr.get(), addr.len);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Prepare waiting for connections
|
||||
success = ::listen(sock_, 0);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::accept(const char* tag) {
|
||||
int peer = ::accept(sock_, nullptr, nullptr);
|
||||
if (peer < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(peer);
|
||||
}
|
||||
|
||||
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::send(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Send failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<const char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::recv(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Recv failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::connect(
|
||||
const char* tag,
|
||||
const address_t& addr,
|
||||
int num_retries,
|
||||
int wait,
|
||||
std::function<void(int, int)> cb) {
|
||||
int sock, success;
|
||||
|
||||
// Attempt to connect `num_retries` times with exponential backoff.
|
||||
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
success = ::connect(sock, addr.get(), addr.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
cb(attempt, wait);
|
||||
if (wait > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
wait <<= 1;
|
||||
}
|
||||
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(sock);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
@@ -1,67 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/socket.h>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
|
||||
const sockaddr* get() const {
|
||||
return (struct sockaddr*)&addr;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port);
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port);
|
||||
|
||||
/**
|
||||
* Small wrapper over a TCP socket to simplify initiating connections.
|
||||
*/
|
||||
class TCPSocket {
|
||||
public:
|
||||
TCPSocket(const char* tag);
|
||||
TCPSocket(const TCPSocket&) = delete;
|
||||
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||
TCPSocket(TCPSocket&& s);
|
||||
TCPSocket& operator=(TCPSocket&&);
|
||||
~TCPSocket();
|
||||
|
||||
void listen(const char* tag, const address_t& addr);
|
||||
TCPSocket accept(const char* tag);
|
||||
|
||||
void send(const char* tag, const void* data, size_t len);
|
||||
void recv(const char* tag, void* data, size_t len);
|
||||
|
||||
int detach();
|
||||
|
||||
operator int() const {
|
||||
return sock_;
|
||||
}
|
||||
|
||||
static TCPSocket connect(
|
||||
const char* tag,
|
||||
const address_t& addr,
|
||||
int num_retries = 1,
|
||||
int wait = 0,
|
||||
std::function<void(int, int)> cb = nullptr);
|
||||
|
||||
private:
|
||||
TCPSocket(int sock);
|
||||
|
||||
int sock_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
@@ -1,85 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Host:
|
||||
rank: int
|
||||
ssh_hostname: str
|
||||
ips: list[str]
|
||||
rdma: list[Optional[str]]
|
||||
|
||||
|
||||
def positive_number(x):
|
||||
x = int(x)
|
||||
if x <= 0:
|
||||
raise ValueError("Number should be positive")
|
||||
return x
|
||||
|
||||
|
||||
def log(verbose, *args, **kwargs):
|
||||
if not verbose:
|
||||
return
|
||||
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_warning(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_error(*args, **kwargs):
|
||||
kwargs["file"] = sys.stderr
|
||||
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def parse_hostlist(parser, hostlist, repeats):
|
||||
hosts = []
|
||||
for i, h in enumerate(hostlist.split(",")):
|
||||
if h == "":
|
||||
raise ValueError("Hostname cannot be empty")
|
||||
try:
|
||||
ipaddress.ip_address(h)
|
||||
ips = [h]
|
||||
except ValueError:
|
||||
ips = []
|
||||
for i in range(repeats):
|
||||
hosts.append(Host(i, h, ips))
|
||||
return hosts
|
||||
|
||||
|
||||
def parse_hostfile(parser, hostfile):
|
||||
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||
the ips to communicate over when using the ring backend.
|
||||
|
||||
Example:
|
||||
|
||||
[
|
||||
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
|
||||
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
|
||||
...
|
||||
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
|
||||
]
|
||||
|
||||
Args:
|
||||
hostfile (str): The path to the json file containing the host
|
||||
information
|
||||
"""
|
||||
hostfile = Path(hostfile)
|
||||
if not hostfile.exists():
|
||||
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
|
||||
|
||||
try:
|
||||
hosts = []
|
||||
with open(hostfile) as f:
|
||||
for i, h in enumerate(json.load(f)):
|
||||
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
|
||||
return hosts
|
||||
except Exception as e:
|
||||
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||
@@ -1,540 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
from collections import Counter
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from queue import Empty as QueueEmpty
|
||||
from queue import Queue
|
||||
from select import select
|
||||
from subprocess import PIPE, Popen, run
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
|
||||
|
||||
|
||||
class CommandProcess:
|
||||
@property
|
||||
def process(self):
|
||||
"""Return the Popen object that refers to the current command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def exit_status(self):
|
||||
"""Return a tuple (returncode, killed) for the command. It should be
|
||||
(None, None) while the command is running normally."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def preprocess_output(self, data: str, is_stdout=False):
|
||||
"""Preprocess the output of the command so that extra data can be
|
||||
capture or the format changed on the fly."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def terminate(self):
|
||||
"""Terminate or return the exit code."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RemoteProcess(CommandProcess):
|
||||
def __init__(self, rank, host, cwd, files, env, command):
|
||||
is_local = host == "127.0.0.1"
|
||||
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
||||
script_b64 = base64.b64encode(script.encode()).decode()
|
||||
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||
if not is_local:
|
||||
cmd = f"ssh {host} '{cmd}'"
|
||||
|
||||
self._host = host
|
||||
self._pidfile = None
|
||||
self._is_local = is_local
|
||||
self._process = Popen(
|
||||
cmd,
|
||||
shell=True,
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
|
||||
self._killed = False
|
||||
|
||||
@property
|
||||
def process(self):
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def exit_status(self):
|
||||
return self._process.poll(), self._killed
|
||||
|
||||
def preprocess_output(self, data, is_stdout=False):
|
||||
if self._pidfile is None:
|
||||
pidfile, *rest = data.split("\n", maxsplit=1)
|
||||
self._pidfile = pidfile
|
||||
return rest[0] if rest else ""
|
||||
|
||||
return data
|
||||
|
||||
def terminate(self):
|
||||
if self._killed:
|
||||
return
|
||||
|
||||
self._process.terminate()
|
||||
self._process.wait()
|
||||
|
||||
# Kill the remote program if possible
|
||||
cmd = ""
|
||||
cmd += f"pid=$(cat {self._pidfile}); "
|
||||
cmd += "if ps -p $pid >/dev/null; then "
|
||||
cmd += " kill $pid; "
|
||||
cmd += " echo 1; "
|
||||
cmd += "else "
|
||||
cmd += " echo 0; "
|
||||
cmd += "fi; "
|
||||
cmd += f"rm {self._pidfile}"
|
||||
if not self._is_local:
|
||||
cmd = f"ssh {self._host} '{cmd}'"
|
||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||
|
||||
self._killed = c.stdout.strip() == "1"
|
||||
|
||||
@staticmethod
|
||||
def make_monitor_script(rank, cwd, files, env, command):
|
||||
# Imports that are used throughout
|
||||
script = ""
|
||||
script += "import os\n"
|
||||
script += "import sys\n"
|
||||
script += "import tempfile\n"
|
||||
script += "from pathlib import Path\n"
|
||||
|
||||
# Write the PID to a file so we can kill the process if needed
|
||||
script += "_, pidfile = tempfile.mkstemp() \n"
|
||||
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
||||
script += "print(pidfile, flush=True)\n"
|
||||
|
||||
# Change the working directory if one was requested. Otherwise attempt to
|
||||
# change to the current one but don't fail if it wasn't possible.
|
||||
d = cwd or os.getcwd()
|
||||
script += f"if Path({repr(d)}).exists():\n"
|
||||
script += f" os.chdir({repr(d)})\n"
|
||||
if cwd is not None:
|
||||
script += "else:\n"
|
||||
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
||||
script += f" sys.exit(1)\n"
|
||||
|
||||
# Add the environment variables that were requested
|
||||
script += "env = dict(os.environ)\n"
|
||||
for e in env:
|
||||
key, *value = e.split("=", maxsplit=1)
|
||||
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
log_warning(
|
||||
f"'{e}' is an invalid environment variable so it is ignored"
|
||||
)
|
||||
continue
|
||||
script += f"env[{repr(key)}] = {repr(value)}\n"
|
||||
|
||||
# Make the temporary files
|
||||
for env_name, content in files.items():
|
||||
script += "_, fname = tempfile.mkstemp()\n"
|
||||
script += "with open(fname, 'w') as f:\n"
|
||||
script += f" f.write({repr(content)})\n"
|
||||
script += f"env[{repr(env_name)}] = fname\n"
|
||||
|
||||
# Finally add the rank
|
||||
script += f"env['MLX_RANK'] = '{rank}'\n"
|
||||
script += "\n"
|
||||
|
||||
# Replace the process with the script
|
||||
script += f"command = [{','.join(map(repr, command))}]\n"
|
||||
script += "os.execve(command[0], command, env)\n"
|
||||
|
||||
return script
|
||||
|
||||
|
||||
def _launch_with_io(command_class, arguments, verbose):
|
||||
stop = False
|
||||
exit_codes = [(None, None)] * len(arguments)
|
||||
|
||||
def _thread_fn(rank, *args, **kwargs):
|
||||
stdin_queue = kwargs.pop("stdin_queue")
|
||||
stdout_queue = kwargs.pop("stdout_queue")
|
||||
stderr_queue = kwargs.pop("stderr_queue")
|
||||
|
||||
command = command_class(rank, *args, **kwargs)
|
||||
p = command.process
|
||||
os.set_blocking(p.stdout.fileno(), False)
|
||||
os.set_blocking(p.stderr.fileno(), False)
|
||||
os.set_blocking(p.stdin.fileno(), False)
|
||||
|
||||
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||
to_write = [p.stdin.fileno()]
|
||||
|
||||
stdin_buffer = b""
|
||||
while p.poll() is None:
|
||||
try:
|
||||
stdin_buffer += stdin_queue.get_nowait()
|
||||
except QueueEmpty:
|
||||
pass
|
||||
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||
for fd in rlist:
|
||||
is_stdout = fd == p.stdout.fileno()
|
||||
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||
msg = command.preprocess_output(msg, is_stdout)
|
||||
if is_stdout:
|
||||
stdout_queue.put(msg.encode())
|
||||
else:
|
||||
stderr_queue.put(msg.encode())
|
||||
for fd in wlist:
|
||||
if len(stdin_buffer) > 0:
|
||||
n = os.write(fd, stdin_buffer)
|
||||
stdin_buffer = stdin_buffer[n:]
|
||||
if stop:
|
||||
command.terminate()
|
||||
break
|
||||
exit_codes[rank] = command.exit_status
|
||||
|
||||
if exit_codes[rank][1]:
|
||||
log_warning(f"Node with rank {rank} was killed")
|
||||
elif exit_codes[rank][0] != 0:
|
||||
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
|
||||
else:
|
||||
log(verbose, f"Node with rank {rank} completed")
|
||||
|
||||
stdin_queues = []
|
||||
stdout_queues = []
|
||||
stderr_queues = []
|
||||
threads = []
|
||||
for i, (args, kwargs) in enumerate(arguments):
|
||||
stdin_queues.append(Queue())
|
||||
stdout_queues.append(Queue())
|
||||
stderr_queues.append(Queue())
|
||||
t = threading.Thread(
|
||||
target=_thread_fn,
|
||||
args=args,
|
||||
kwargs=kwargs
|
||||
| {
|
||||
"stdin_queue": stdin_queues[-1],
|
||||
"stdout_queue": stdout_queues[-1],
|
||||
"stderr_queue": stderr_queues[-1],
|
||||
},
|
||||
)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
os.set_blocking(sys.stdin.fileno(), False)
|
||||
os.set_blocking(sys.stdout.fileno(), True)
|
||||
os.set_blocking(sys.stderr.fileno(), True)
|
||||
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
|
||||
# Broadcast user input to the jobs
|
||||
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
|
||||
for fd in rlist:
|
||||
stdin_buffer = os.read(fd, 8192)
|
||||
for q in stdin_queues:
|
||||
q.put(stdin_buffer)
|
||||
|
||||
# Gather job output
|
||||
for q in stdout_queues:
|
||||
try:
|
||||
while not q.empty():
|
||||
sys.stdout.buffer.write(q.get_nowait())
|
||||
except QueueEmpty:
|
||||
pass
|
||||
for q in stderr_queues:
|
||||
try:
|
||||
while not q.empty():
|
||||
sys.stderr.buffer.write(q.get_nowait())
|
||||
except QueueEmpty:
|
||||
pass
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
# Check if all are running and terminate otherwise
|
||||
if any(t.is_alive() for t in threads):
|
||||
for i, t in enumerate(threads):
|
||||
if not t.is_alive():
|
||||
if exit_codes[i][0] != 0:
|
||||
stop = True
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Wait for the jobs to finish
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Process any remaining outputs
|
||||
for q in stdout_queues:
|
||||
while not q.empty():
|
||||
sys.stdout.buffer.write(q.get())
|
||||
for q in stderr_queues:
|
||||
while not q.empty():
|
||||
sys.stderr.buffer.write(q.get())
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
|
||||
def launch_ring(parser, hosts, args, command):
|
||||
if any(len(h.ips) == 0 for h in hosts):
|
||||
parser.error(
|
||||
"The ring backend requires IPs to be provided instead of hostnames"
|
||||
)
|
||||
|
||||
port = args.starting_port
|
||||
ring_hosts = []
|
||||
for h in hosts:
|
||||
node = []
|
||||
for ip in h.ips:
|
||||
for i in range(args.connections_per_ip):
|
||||
node.append(f"{ip}:{port}")
|
||||
port += 1
|
||||
ring_hosts.append(node)
|
||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||
|
||||
files = {"MLX_HOSTFILE": hostfile}
|
||||
env = args.env
|
||||
if args.verbose:
|
||||
env.append("MLX_RING_VERBOSE=1")
|
||||
cwd = args.cwd
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
_launch_with_io(
|
||||
RemoteProcess,
|
||||
[
|
||||
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||
for rank, h in enumerate(hosts)
|
||||
],
|
||||
args.verbose,
|
||||
)
|
||||
|
||||
|
||||
def launch_nccl(parser, hosts, args, command):
|
||||
if not hosts[0].ips:
|
||||
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||
|
||||
master_host = hosts[0].ips[0]
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
env = args.env
|
||||
cwd = args.cwd
|
||||
if args.verbose:
|
||||
env.append("NCCL_DEBUG=INFO")
|
||||
env.append(f"NCCL_HOST_IP={master_host}")
|
||||
env.append(f"NCCL_PORT={master_port}")
|
||||
env.append(f"MLX_WORLD_SIZE={world_size}")
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
_launch_with_io(
|
||||
RemoteProcess,
|
||||
[
|
||||
(
|
||||
(
|
||||
rank,
|
||||
h.ssh_hostname,
|
||||
cwd,
|
||||
{},
|
||||
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
||||
command,
|
||||
),
|
||||
{},
|
||||
)
|
||||
for rank, h in enumerate(hosts)
|
||||
],
|
||||
args.verbose,
|
||||
)
|
||||
|
||||
|
||||
def launch_jaccl(parser, hosts, args, command):
|
||||
if not hosts[0].ips:
|
||||
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||
|
||||
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
|
||||
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
|
||||
if not have_rdmas or not have_nulls:
|
||||
raise ValueError("Malformed hostfile for jaccl backend")
|
||||
|
||||
coordinator = hosts[0].ips[0]
|
||||
env = args.env
|
||||
cwd = args.cwd
|
||||
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
|
||||
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
|
||||
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
_launch_with_io(
|
||||
RemoteProcess,
|
||||
[
|
||||
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||
for rank, h in enumerate(hosts)
|
||||
],
|
||||
args.verbose,
|
||||
)
|
||||
|
||||
|
||||
def get_mpi_libname():
|
||||
try:
|
||||
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
|
||||
ompi_info = ompi_info.stdout.strip().decode()
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
otool_output = run(
|
||||
["otool", "-L", ompi_info], check=True, capture_output=True
|
||||
)
|
||||
else:
|
||||
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
|
||||
otool_output = otool_output.stdout.decode()
|
||||
|
||||
# StopIteration if not found
|
||||
libmpi_line = next(
|
||||
filter(lambda line: "libmpi" in line, otool_output.splitlines())
|
||||
)
|
||||
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def launch_mpi(parser, hosts, args, command):
|
||||
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
|
||||
mpirun = mpirun.stdout.strip().decode()
|
||||
|
||||
# Compatibility with homebrew and pip installs
|
||||
mpi_libname = get_mpi_libname()
|
||||
if mpi_libname is not None:
|
||||
dyld = Path(mpirun).parent.parent / "lib"
|
||||
args.env = [
|
||||
f"DYLD_LIBRARY_PATH={str(dyld)}",
|
||||
f"MLX_MPI_LIBNAME={mpi_libname}",
|
||||
] + args.env
|
||||
|
||||
log(args.verbose, f"Using '{mpirun}'")
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
hosts = Counter((h.ssh_hostname for h in hosts))
|
||||
for h, n in hosts.items():
|
||||
print(f"{h} slots={n}", file=f)
|
||||
f.flush()
|
||||
|
||||
cmd = [
|
||||
mpirun,
|
||||
"--output",
|
||||
":raw", # do not line buffer output
|
||||
"--hostfile",
|
||||
f.name,
|
||||
*(["-cwd", args.cwd] if args.cwd else []),
|
||||
*sum((["-x", e] for e in args.env), []),
|
||||
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||
"--",
|
||||
*command,
|
||||
]
|
||||
log(args.verbose, "Running", " ".join(cmd))
|
||||
try:
|
||||
run(cmd)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||
parser.add_argument(
|
||||
"--print-python",
|
||||
action="store_true",
|
||||
help="Print the path to the current python executable and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat-hosts",
|
||||
"-n",
|
||||
type=positive_number,
|
||||
default=1,
|
||||
help="Repeat each host a given number of times",
|
||||
)
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||
default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Set environment variables for the jobs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mpi-arg",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Arguments to pass directly to mpirun",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connections-per-ip",
|
||||
default=1,
|
||||
type=int,
|
||||
help="How many connections per ip to use for the ring backend",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--starting-port",
|
||||
"-p",
|
||||
type=int,
|
||||
default=32323,
|
||||
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cwd", help="Set the working directory on each node to the provided one"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nccl-port",
|
||||
type=int,
|
||||
default=12345,
|
||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
|
||||
if args.print_python:
|
||||
print(sys.executable)
|
||||
return
|
||||
|
||||
if len(rest) == 0:
|
||||
parser.error("No script is provided")
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
|
||||
# Try to extract a list of hosts and corresponding ips
|
||||
if args.hostfile is not None:
|
||||
hosts = parse_hostfile(parser, args.hostfile)
|
||||
else:
|
||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||
|
||||
# Check if the script is a file and convert it to a full path
|
||||
if (script := Path(rest[0])).exists() and script.is_file():
|
||||
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||
elif (command := shutil.which(rest[0])) is not None:
|
||||
rest[0] = command
|
||||
else:
|
||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||
|
||||
# Launch
|
||||
if args.backend == "ring":
|
||||
launch_ring(parser, hosts, args, rest)
|
||||
if args.backend == "mpi":
|
||||
launch_mpi(parser, hosts, args, rest)
|
||||
if args.backend == "nccl":
|
||||
launch_nccl(parser, hosts, args, rest)
|
||||
if args.backend == "jaccl":
|
||||
launch_jaccl(parser, hosts, args, rest)
|
||||
@@ -832,7 +832,7 @@ def main():
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="nccl" if mx.cuda.is_available() else "ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
@@ -903,8 +903,6 @@ def main():
|
||||
launch_mpi(parser, hosts, args, rest)
|
||||
if args.backend == "nccl":
|
||||
launch_nccl(parser, hosts, args, rest)
|
||||
if args.backend == "jaccl":
|
||||
launch_jaccl(parser, hosts, args, rest)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
auditwheel repair dist/* \
|
||||
--plat manylinux_2_35_x86_64 \
|
||||
--plat manylinux_2_35_${1} \
|
||||
--exclude libcublas* \
|
||||
--exclude libnvrtc* \
|
||||
--exclude libcuda* \
|
||||
|
||||
@@ -52,25 +52,9 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"is_available",
|
||||
[](const std::string& backend) {
|
||||
return mx::distributed::is_available(backend);
|
||||
},
|
||||
"backend"_a = "any",
|
||||
nb::sig("def is_available(backend: str = 'any') -> bool"),
|
||||
&mx::distributed::is_available,
|
||||
R"pbdoc(
|
||||
Check if a communication backend is available.
|
||||
|
||||
Note, this function returns whether MLX has the capability of
|
||||
instantiating that distributed backend not whether it is possible to
|
||||
create a communication group. For that purpose one should use
|
||||
``init(strict=True)``.
|
||||
|
||||
Args:
|
||||
backend (str, optional): The name of the backend to check for availability.
|
||||
It takes the same values as ``init()``. Default: ``any``.
|
||||
|
||||
Returns:
|
||||
bool: Whether the distributed backend is available.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
@@ -95,10 +79,10 @@ void init_distributed(nb::module_& parent_module) {
|
||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||
it throws a runtime error. Default: ``False``
|
||||
backend (str, optional): Which distributed backend to initialize.
|
||||
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
|
||||
set to ``any`` all available backends are tried and the first one
|
||||
that succeeds becomes the global group which will be returned in
|
||||
subsequent calls. Default: ``any``
|
||||
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
|
||||
available backends are tried and the first one that succeeds
|
||||
becomes the global group which will be returned in subsequent
|
||||
calls. Default: ``any``
|
||||
|
||||
Returns:
|
||||
Group: The group representing all the launched processes.
|
||||
|
||||
@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
ref = getattr(np, op)(np_arr, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
def test_long_column(self):
|
||||
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
|
||||
b = mx.array(a)
|
||||
|
||||
c1 = a.sum(0)
|
||||
c2 = b.sum(0)
|
||||
self.assertTrue(np.all(c1 == c2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
||||
40
setup.py
40
setup.py
@@ -7,13 +7,21 @@ import re
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from subprocess import run
|
||||
|
||||
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||
from setuptools.command.bdist_wheel import bdist_wheel
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def cuda_toolkit_major_version():
|
||||
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
|
||||
text = out.decode()
|
||||
m = re.search(r"release (\d+)", text)
|
||||
if m:
|
||||
return int(m.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def get_version():
|
||||
with open("mlx/version.h", "r") as fid:
|
||||
for l in fid:
|
||||
@@ -31,7 +39,7 @@ def get_version():
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
if not pypi_release and not dev_release:
|
||||
git_hash = (
|
||||
run(
|
||||
subprocess.run(
|
||||
"git rev-parse --short HEAD".split(),
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@@ -257,8 +265,8 @@ if __name__ == "__main__":
|
||||
}
|
||||
entry_points = {
|
||||
"console_scripts": [
|
||||
"mlx.launch = mlx._distributed_utils.launch:main",
|
||||
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||
"mlx.launch = mlx.distributed_run:main",
|
||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||
]
|
||||
}
|
||||
install_requires = []
|
||||
@@ -284,7 +292,11 @@ if __name__ == "__main__":
|
||||
install_requires.append(
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||
)
|
||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
|
||||
for toolkit in [12, 13]:
|
||||
extras[f"cuda{toolkit}"] = [
|
||||
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
|
||||
]
|
||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||
|
||||
_setup(
|
||||
@@ -299,13 +311,25 @@ if __name__ == "__main__":
|
||||
if build_macos:
|
||||
name = "mlx-metal"
|
||||
elif build_cuda:
|
||||
name = "mlx-cuda"
|
||||
toolkit = cuda_toolkit_major_version()
|
||||
name = f"mlx-cuda-{toolkit}"
|
||||
if toolkit == 12:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
"nvidia-cudnn-cu12==9.*",
|
||||
"nvidia-nccl-cu12",
|
||||
]
|
||||
elif toolkit == 13:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu13",
|
||||
"nvidia-cuda-nvrtc-cu13",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown toolkit {toolkit}")
|
||||
install_requires += [
|
||||
f"nvidia-cudnn-cu{toolkit}==9.*",
|
||||
f"nvidia-nccl-cu{toolkit}",
|
||||
]
|
||||
|
||||
else:
|
||||
name = "mlx-cpu"
|
||||
_setup(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -608,3 +607,24 @@ TEST_CASE("test make empty array") {
|
||||
CHECK_EQ(a.size(), 0);
|
||||
CHECK_EQ(a.dtype(), bool_);
|
||||
}
|
||||
|
||||
TEST_CASE("test make array from user buffer") {
|
||||
int size = 4096;
|
||||
std::vector<int> buffer(size, 0);
|
||||
|
||||
int count = 0;
|
||||
auto deleter = [&count](void*) { count++; };
|
||||
|
||||
{
|
||||
auto a = array(buffer.data(), Shape{size}, int32, deleter);
|
||||
if (metal::is_available()) {
|
||||
CHECK_EQ(buffer.data(), a.data<int>());
|
||||
}
|
||||
auto b = a + array(1);
|
||||
eval(b);
|
||||
auto expected = ones({4096});
|
||||
CHECK(array_equal(b, expected).item<bool>());
|
||||
}
|
||||
// deleter should always get called
|
||||
CHECK_EQ(count, 1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user