Compare commits

..

13 Commits

Author SHA1 Message Date
Angelos Katharopoulos
dd91ee9534 Refactoring launcher 2025-12-08 02:57:50 -08:00
Angelos Katharopoulos
8fab4f0929 Change the name to a fun pun 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
47af2c8cb0 Add headers for gcc 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
f40152ebc1 Expose per-backend availability in C++ and python 2025-12-04 14:20:52 -08:00
Angelos Katharopoulos
5d7e6a0642 Add a no_ibv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b9b78b1059 Add empty sum_scatter 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
45727b0c02 Add send/recv 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
2444fbdfe9 Make sure that there is space for work completions 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
f3b605e53c Add working reduce and semi-working all gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
0388ae3aaf Fix ring 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
d4c1de4a8b Fix side channel initialization for more than 2 peers 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
4dbffb3954 All gather 2025-12-04 14:20:51 -08:00
Angelos Katharopoulos
b1a60b2d2d Initial working all reduce 2025-12-04 14:20:51 -08:00
24 changed files with 105 additions and 426 deletions

View File

@@ -1,15 +1,6 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
@@ -21,4 +12,4 @@ runs:
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
bash python/scripts/repair_cuda.sh

View File

@@ -15,7 +15,6 @@ runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,11 +128,7 @@ jobs:
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -140,11 +136,9 @@ jobs:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -29,20 +29,17 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda12]
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.5
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^

View File

@@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

24
mlx/allocator.cpp Normal file
View File

@@ -0,0 +1,24 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,25 +49,4 @@ class Allocator {
Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -82,28 +82,6 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,16 +57,6 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,

View File

@@ -20,19 +20,6 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
@@ -48,7 +35,13 @@ SmallSizePool::SmallSizePool() {
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
auto loc = cuda_mem_loc(i);
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
@@ -97,10 +90,9 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.9;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -112,10 +104,6 @@ CudaAllocator::CudaAllocator()
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
@@ -166,35 +154,23 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
cudaError_t err;
void* data = nullptr;
if (device == -1) {
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
err = cudaMallocManaged(&data, size);
} else {
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!data) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
return Buffer{nullptr};
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
}
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);

View File

@@ -71,14 +71,11 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_;
};

View File

@@ -95,14 +95,11 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 8;
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4 && dim0 < 8) {
if (dim0 >= 4) {
work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
@@ -113,10 +110,7 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
@@ -133,9 +127,7 @@ void copy_general_input(
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(

View File

@@ -318,52 +318,46 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"});
}
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// Constructs a key representing the nodes of a sub-graph.
// Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return {key + ")", true};
return true;
}
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
return is_graph_updatable(child, cluster_dim_x);
} else if (type != cudaGraphNodeTypeKernel) {
is_updatable = false;
return false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
}
}
key += ")";
return {key, is_updatable};
return true;
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -376,10 +370,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_updatable;
int cluster_dim_x = 0;
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, sub_graph_key});
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
}
bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// () = subgraph (with metadata)
// G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;

View File

@@ -89,13 +89,9 @@ template <
int NDIM,
int BM,
int BN,
int N_READS = 4,
int BLOCKS = 1>
__global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -106,8 +102,6 @@ __global__ void col_reduce_looped(
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
@@ -124,23 +118,12 @@ __global__ void col_reduce_looped(
totals[i] = ReduceInit<Op, T>::value();
}
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -149,7 +132,7 @@ __global__ void col_reduce_looped(
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -159,7 +142,7 @@ __global__ void col_reduce_looped(
}
}
} else {
for (size_t r = start; r < end; r += BM) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
@@ -190,9 +173,6 @@ __global__ void col_reduce_looped(
// Write result.
if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN,
@@ -247,12 +227,11 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args,
int bn,
int outer = 1) {
int bn) {
int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
size_t n_blocks = n_outer_blocks * n_inner_blocks;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
}
@@ -298,8 +277,7 @@ void col_reduce_looped(
0,
indata,
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
static_cast<cu::ColReduceArgs>(args));
});
});
});
@@ -342,117 +320,6 @@ void col_reduce_small(
});
}
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -467,18 +334,6 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
@@ -494,14 +349,6 @@ void col_reduce(
return;
}
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -80,6 +80,7 @@ CudaGraph::CudaGraph(cu::Device& device) {
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}

View File

@@ -7,6 +7,8 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}

View File

@@ -149,9 +149,7 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
return Buffer{nullptr};
}
lk.lock();
num_resources_++;
@@ -203,32 +201,6 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,9 +21,6 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};

View File

@@ -25,7 +25,6 @@ class CommonAllocator : public Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_${1} \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc* \
--exclude libcuda* \

View File

@@ -210,14 +210,6 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,21 +7,13 @@ import re
import subprocess
from functools import partial
from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version():
with open("mlx/version.h", "r") as fid:
for l in fid:
@@ -39,7 +31,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release:
git_hash = (
subprocess.run(
run(
"git rev-parse --short HEAD".split(),
capture_output=True,
check=True,
@@ -292,11 +284,7 @@ if __name__ == "__main__":
install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"'
)
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup(
@@ -311,25 +299,13 @@ if __name__ == "__main__":
if build_macos:
name = "mlx-metal"
elif build_cuda:
toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
name = "mlx-cuda"
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else:
name = "mlx-cpu"
_setup(

View File

@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <climits>
#include "doctest/doctest.h"
@@ -607,24 +608,3 @@ TEST_CASE("test make empty array") {
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_);
}
TEST_CASE("test make array from user buffer") {
int size = 4096;
std::vector<int> buffer(size, 0);
int count = 0;
auto deleter = [&count](void*) { count++; };
{
auto a = array(buffer.data(), Shape{size}, int32, deleter);
if (metal::is_available()) {
CHECK_EQ(buffer.data(), a.data<int>());
}
auto b = a + array(1);
eval(b);
auto expected = ones({4096});
CHECK(array_equal(b, expected).item<bool>());
}
// deleter should always get called
CHECK_EQ(count, 1);
}