Compare commits

..

13 Commits

Author SHA1 Message Date
Angelos Katharopoulos
dd91ee9534 Refactoring launcher 2025-12-08 02:57:50 -08:00
Angelos Katharopoulos
8fab4f0929 Change the name to a fun pun 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
47af2c8cb0 Add headers for gcc 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
f40152ebc1 Expose per-backend availability in C++ and python 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
5d7e6a0642 Add a no_ibv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b9b78b1059 Add empty sum_scatter 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
45727b0c02 Add send/recv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
2444fbdfe9 Make sure that there is space for work completions 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
f3b605e53c Add working reduce and semi-working all gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
0388ae3aaf Fix ring 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
d4c1de4a8b Fix side channel initialization for more than 2 peers 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
4dbffb3954 All gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b1a60b2d2d Initial working all reduce 2025-12-04 14:20:51 -08:00
24 changed files with 105 additions and 426 deletions

View File

@@ -1,15 +1,6 @@
name: 'Build CUDA wheel' name: 'Build CUDA wheel'
description: 'Build CUDA wheel' description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs: runs:
using: "composite" using: "composite"
steps: steps:
@@ -21,4 +12,4 @@ runs:
pip install auditwheel build patchelf setuptools pip install auditwheel build patchelf setuptools
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }} bash python/scripts/repair_cuda.sh

View File

@@ -15,7 +15,6 @@ runs:
using: "composite" using: "composite"
steps: steps:
- name: Use ccache - name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2 uses: hendrikmuhs/ccache-action@v1.2
with: with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,11 +128,7 @@ jobs:
build_cuda_release: build_cuda_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
strategy: runs-on: ubuntu-22-large
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }} DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -140,11 +136,9 @@ jobs:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
toolkit: ${{ matrix.toolkit }} toolkit: 'cuda-12.9'
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:

View File

@@ -29,20 +29,17 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda12] pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following To install the CUDA package from PyPi your system must meet the following
requirements: requirements:
- Nvidia architecture >= SM 7.5 - Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

24
mlx/allocator.cpp Normal file
View File

@@ -0,0 +1,24 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
}; };
}; };
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@@ -49,25 +49,4 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -82,28 +82,6 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin()); init(data.begin());
} }
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,16 +57,6 @@ class array {
Shape shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */ /* Build an array from a buffer */
explicit array( explicit array(
allocator::Buffer data, allocator::Buffer data,

View File

@@ -20,19 +20,6 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool // Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8; constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page // The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size. // size and small_block_size.
constexpr int small_pool_size = 4 * page_size; constexpr int small_pool_size = 4 * page_size;
@@ -48,7 +35,13 @@ SmallSizePool::SmallSizePool() {
int device_count = 0; int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
auto loc = cuda_mem_loc(i); #if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
} }
@@ -97,10 +90,9 @@ CudaAllocator::CudaAllocator()
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total_memory_ * 0.95; memory_limit_ = total * 0.9;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
int device_count = 0; int device_count = 0;
@@ -112,10 +104,6 @@ CudaAllocator::CudaAllocator()
cudaStream_t s; cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking)); CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s); free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
} }
CHECK_CUDA_ERROR(cudaSetDevice(curr)); CHECK_CUDA_ERROR(cudaSetDevice(curr));
} }
@@ -166,35 +154,23 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
} }
lock.unlock(); lock.unlock();
if (!buf) { if (!buf) {
cudaError_t err;
void* data = nullptr; void* data = nullptr;
if (device == -1) { if (device == -1) {
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size)); err = cudaMallocManaged(&data, size);
} else { } else {
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream)); err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
} }
if (!data) { if (!data) {
std::ostringstream msg; return Buffer{nullptr};
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
buf = new CudaBuffer{data, size, device}; buf = new CudaBuffer{data, size, device};
} }
lock.lock(); lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
} }
active_memory_ += buf->size; active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_); peak_memory_ = std::max(active_memory_, peak_memory_);

View File

@@ -71,14 +71,11 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_; size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_; std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_; SmallSizePool scalar_pool_;
}; };

View File

@@ -95,14 +95,11 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1; auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0; auto rest = out.size() / dim0;
if (dim0 >= 4 && dim0 < 8) { if (dim0 >= 4) {
work_per_thread = 4; work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
} }
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1); auto block_dims = get_block_dims(dim0, rest, 1);
@@ -113,10 +110,7 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 8) { if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel = kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
} }
@@ -133,9 +127,7 @@ void copy_general_input(
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>; auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 8) { if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>; kernel = cu::copy_g<InType, OutType, IdxT, 4>;
} }
encoder.add_kernel_node( encoder.add_kernel_node(

View File

@@ -318,52 +318,46 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"}); insert_graph_dependencies(GraphNode{node, "K"});
} }
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) { bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// Constructs a key representing the nodes of a sub-graph. // CUDA graphs do not get updated correctly if a kernel node getting updated
// Also checks if the sub-graph is updatable as CUDA graphs do not get // has a different cluster shape than the node it's being updated with.
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
size_t num_nodes = 0; size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) { if (num_nodes == 0) {
return {key + ")", true}; return true;
} }
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes); std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) { for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type; cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) { if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel // Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child; cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child); return is_graph_updatable(child, cluster_dim_x);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
} else if (type != cudaGraphNodeTypeKernel) { } else if (type != cudaGraphNodeTypeKernel) {
is_updatable = false; return false;
} else { } else {
cudaLaunchAttributeValue cluster_dim; cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute( CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim)); node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1 // Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) { if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false; return false;
} else { }
key += "K"; // Only one child node allowed when subgraph uses clusters
key += std::to_string(cluster_dim.clusterDim.x); if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
} }
} }
} return true;
key += ")";
return {key, is_updatable};
} }
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -376,10 +370,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return; return;
} }
cudaGraphNode_t node; cudaGraphNode_t node;
auto [sub_graph_key, is_updatable] = subgraph_to_key(child); int cluster_dim_x = 0;
is_graph_updatable_ &= is_updatable; is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, sub_graph_key}); insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
} }
bool CommandEncoder::needs_commit() { bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node; cudaGraphNode_t node;
// K = kernel // K = kernel
// E = empty // E = empty
// () = subgraph (with metadata) // G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators // Symbols ':', '-' are reserved as separators
std::string node_type; std::string node_type;
std::string id; std::string id;

View File

@@ -89,13 +89,9 @@ template <
int NDIM, int NDIM,
int BM, int BM,
int BN, int BN,
int N_READS = 4, int N_READS = 4>
int BLOCKS = 1> __global__ void
__global__ void col_reduce_looped( col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -106,8 +102,6 @@ __global__ void col_reduce_looped(
size_t tile_idx = grid.block_rank(); size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile // Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row; short thread_x = block.thread_rank() % threads_per_row;
@@ -124,23 +118,12 @@ __global__ void col_reduce_looped(
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
if (tile_x * BN + BN <= args.reduction_stride) { if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) { if (args.reduction_stride % N_READS == 0) {
for (size_t r = start; r < end; r += BM) { for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -149,7 +132,7 @@ __global__ void col_reduce_looped(
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
} else { } else {
for (size_t r = start; r < end; r += BM) { for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -159,7 +142,7 @@ __global__ void col_reduce_looped(
} }
} }
} else { } else {
for (size_t r = start; r < end; r += BM) { for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
thread_x, thread_x,
@@ -190,9 +173,6 @@ __global__ void col_reduce_looped(
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN, out + tile_y * args.reduction_stride + tile_x * BN,
@@ -247,12 +227,11 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args, const cu::ColReduceArgs& args,
int bn, int bn) {
int outer = 1) {
int gx, gy = 1; int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride; size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer; size_t n_blocks = n_outer_blocks * n_inner_blocks;
while (n_blocks / gy > INT32_MAX) { while (n_blocks / gy > INT32_MAX) {
gy *= 2; gy *= 2;
} }
@@ -298,8 +277,7 @@ void col_reduce_looped(
0, 0,
indata, indata,
gpu_ptr<U>(out), gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args), static_cast<cu::ColReduceArgs>(args));
out.size() / args.reduction_stride);
}); });
}); });
}); });
@@ -342,117 +320,6 @@ void col_reduce_small(
}); });
} }
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -467,18 +334,6 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for // It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements. // a subrow of the fast moving axis. For instance 32 elements.
// //
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and // Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend). // leave transpositions as they are (contrary to our Metal backend).
// //
@@ -494,14 +349,6 @@ void col_reduce(
return; return;
} }
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }

View File

@@ -80,6 +80,7 @@ CudaGraph::CudaGraph(cu::Device& device) {
} }
void CudaGraph::end_capture(cudaStream_t stream) { void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
} }

View File

@@ -7,6 +7,8 @@
namespace mlx::core { namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) { void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }

View File

@@ -149,9 +149,7 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) { if (!buf) {
std::ostringstream msg; return Buffer{nullptr};
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
lk.lock(); lk.lock();
num_resources_++; num_resources_++;
@@ -203,32 +201,6 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length(); return static_cast<MTL::Buffer*>(buffer.ptr())->length();
} }
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() { MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator // By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This // will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,9 +21,6 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };

View File

@@ -25,7 +25,6 @@ class CommonAllocator : public Allocator {
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const { size_t get_active_memory() const {
return active_memory_; return active_memory_;
}; };

View File

@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_35_${1} \ --plat manylinux_2_35_x86_64 \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \ --exclude libcuda* \

View File

@@ -210,14 +210,6 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis) ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True)) self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True) mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,21 +7,13 @@ import re
import subprocess import subprocess
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version(): def get_version():
with open("mlx/version.h", "r") as fid: with open("mlx/version.h", "r") as fid:
for l in fid: for l in fid:
@@ -39,7 +31,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release: if not pypi_release and not dev_release:
git_hash = ( git_hash = (
subprocess.run( run(
"git rev-parse --short HEAD".split(), "git rev-parse --short HEAD".split(),
capture_output=True, capture_output=True,
check=True, check=True,
@@ -292,11 +284,7 @@ if __name__ == "__main__":
install_requires.append( install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"' f'mlx-metal=={version}; platform_system == "Darwin"'
) )
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"'] extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup( _setup(
@@ -311,25 +299,13 @@ if __name__ == "__main__":
if build_macos: if build_macos:
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
toolkit = cuda_toolkit_major_version() name = "mlx-cuda"
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [ install_requires += [
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(

View File

@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <climits> #include <climits>
#include "doctest/doctest.h" #include "doctest/doctest.h"
@@ -607,24 +608,3 @@ TEST_CASE("test make empty array") {
CHECK_EQ(a.size(), 0); CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_); CHECK_EQ(a.dtype(), bool_);
} }
TEST_CASE("test make array from user buffer") {
int size = 4096;
std::vector<int> buffer(size, 0);
int count = 0;
auto deleter = [&count](void*) { count++; };
{
auto a = array(buffer.data(), Shape{size}, int32, deleter);
if (metal::is_available()) {
CHECK_EQ(buffer.data(), a.data<int>());
}
auto b = a + array(1);
eval(b);
auto expected = ones({4096});
CHECK(array_equal(b, expected).item<bool>());
}
// deleter should always get called
CHECK_EQ(count, 1);
}