Compare commits

..

14 Commits

Author SHA1 Message Date
Angelos Katharopoulos
fd1d0821d2 Make sure softmax doesn't change the actual max 2025-06-22 23:34:32 -07:00
Angelos Katharopoulos
818e8e663e Add an init reduce 2025-06-22 21:28:41 -07:00
Angelos Katharopoulos
cc4b995723 Working col reduce 2025-06-21 23:39:40 -07:00
Angelos Katharopoulos
664d8e42b8 Add comments and clean up 2025-06-21 13:03:37 -07:00
Angelos Katharopoulos
abdb21f27c Add helpers and atomic kernel 2025-06-21 12:37:35 -07:00
Angelos Katharopoulos
880751a084 Remove segmented reduce and fix row reduce 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
cd523ffd9f Working row reduce looped 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
4d2b682a13 Simple row reduce 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
b70a964cde Optimize all reduce a bit 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
9cf7ef1068 Add all reduce and atomic updates 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
ab7c310914 Adapt the torch benchmark to run in CUDA 2025-06-20 21:49:15 -07:00
Angelos Katharopoulos
5adf185f86
Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584
Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d
Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
24 changed files with 559 additions and 436 deletions

View File

@ -16,6 +16,9 @@ parameters:
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs:
build_documentation:
@ -104,7 +107,7 @@ jobs:
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
@ -162,7 +165,7 @@ jobs:
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
@ -223,7 +226,6 @@ jobs:
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
@ -283,7 +285,7 @@ jobs:
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
@ -342,7 +344,7 @@ jobs:
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
@ -356,6 +358,48 @@ jobs:
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
@ -625,3 +669,14 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
conda install conda-forge::mlx
CUDA
^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx-cuda
Troubleshooting
^^^^^^^^^^^^^^^
@ -65,6 +75,8 @@ Build Requirements
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
@ -107,6 +119,8 @@ IDE:
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
@ -185,6 +199,7 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^

View File

@ -5,11 +5,9 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
Shape shape,
Strides strides,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core

View File

@ -31,6 +31,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu

View File

@ -3,6 +3,7 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
@ -14,9 +15,11 @@ namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator()
: buffer_cache_(
getpagesize(),
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) {
cuda_free(buf->data);
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,

View File

@ -24,7 +24,6 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;

View File

@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
worker_.commit();
commit();
f.wait();
}

View File

@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);

View File

@ -162,11 +162,15 @@ class MatMul {
}
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
@ -183,8 +187,8 @@ class MatMul {
out,
out_desc_,
&heuristic_.algo,
workspace.data<void>(),
workspace.nbytes(),
workspace_ptr,
heuristic_.workspaceSize,
stream));
});
}
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(),
c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@ -25,7 +25,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s);
if (in.size() == 0) {
throw std::runtime_error("Should never reach here.");
init_reduce(encoder, in, out, reduce_type_);
return;
}
// Reduce.

View File

@ -15,6 +15,9 @@ namespace cg = cooperative_groups;
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@ -23,10 +26,8 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
ReduceOp op;
T vals[N];
U accs[N];
for (int i = 0; i < N; i++) {
accs[i] = init;
}
U accs[M];
accs[0] = init;
size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
@ -35,7 +36,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[j] = op(accs[j], __cast<U, T>(vals[j]));
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}
@ -45,26 +46,12 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
cub::LoadDirectBlocked(
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[i] = op(accs[i], __cast<U, T>(vals[i]));
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
}
for (int i = 1; i < N; i++) {
accs[0] = op(accs[0], accs[i]);
}
accs[0] = cg::reduce(warp, accs[0], op);
if (warp.meta_group_size() > 1) {
__shared__ U shared_accumulators[32];
if (warp.thread_rank() == 0) {
shared_accumulators[warp.meta_group_rank()] = accs[0];
}
block.sync();
accs[0] = (warp.thread_rank() < warp.meta_group_size())
? shared_accumulators[warp.thread_rank()]
: init;
accs[0] = cg::reduce(warp, accs[0], op);
}
__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];

View File

@ -64,86 +64,6 @@ struct ColReduceArgs {
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template <
typename T,
typename U,
@ -152,67 +72,83 @@ template <
int BM,
int BN,
int N_READS = 4>
__global__ void col_reduce_looped(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int n_warps = BN / N_READS;
constexpr int threads_per_row = BN / N_READS;
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
// Compute the indices for the tile
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
short thread_y = block.thread_rank() / threads_per_row;
// Move the input pointer
in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
tile_x * BN;
// Initialize the running totals
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location() + in_offset),
vals,
args.reduction_stride - in_offset,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
if (tile_x * BN + BN <= args.reduction_stride) {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
in + loop.location(),
vals,
args.reduction_stride - tile_x * BN,
__cast<T, U>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
// Do warp reduce for each output.
constexpr int n_outputs = BN / n_warps;
constexpr int n_outputs = BN / threads_per_row;
static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN];
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
short s_idx = thread_y * BN + thread_x * N_READS;
for (int i = 0; i < N_READS; i++) {
shared_vals[col + i] = totals[i];
shared_vals[s_idx + i] = totals[i];
}
block.sync();
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
}
// Write result.
if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + out_idx * args.reduction_stride + out_offset,
out + tile_y * args.reduction_stride + tile_x * BN,
totals,
args.reduction_stride - out_offset);
args.reduction_stride - tile_x * BN);
}
}
@ -230,6 +166,53 @@ inline auto output_grid_for_col_reduce(
return get_2d_grid_dims(out_shape, out_strides);
}
void col_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args);
size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
if (grid.x * extra_blocks < INT32_MAX) {
grid.x *= extra_blocks;
} else if (grid.y * extra_blocks < 65536) {
grid.y *= extra_blocks;
} else {
throw std::runtime_error(
"[col_reduce_looped] Need to factorize reduction_stride");
}
int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
});
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@ -237,42 +220,24 @@ void col_reduce(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
// Current col reduce options
//
// - col_reduce_looped
//
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave
// transpositions as they are (contrary to our Metal backend).
//
// Moreover we need different kernels for short rows and tuning
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}
} // namespace mlx::core

View File

@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename U, typename Op>
__global__ void init_reduce(U* out, size_t size) {
auto index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = ReduceInit<Op, T>::value();
}
}
} // namespace cu
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::init_reduce<T, U, OP>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
});
});
}
} // namespace mlx::core

View File

@ -53,14 +53,6 @@ void all_reduce(
array& out,
Reduce::ReduceType reduce_type);
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
@ -77,4 +69,10 @@ void col_reduce(
const std::vector<int>& axes,
const ReductionPlan& plan);
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type);
} // namespace mlx::core

View File

@ -4,7 +4,14 @@
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <size_t N>
struct uint_by_size;
@ -62,4 +69,66 @@ inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
} // namespace mlx::core::cu
template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
// First reduce in the current warp
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
// Reduce across warps
if (warp.meta_group_size() > 1) {
if (warp.thread_rank() == 0) {
for (int i = 0; i < N; i++) {
smem[warp.meta_group_rank() * N + i] = vals[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < N; i++) {
vals[i] = smem[warp.thread_rank() * N + i];
}
} else {
for (int i = 0; i < N; i++) {
vals[i] = init;
}
}
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
}
}
} // namespace cu
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
// Initialize out such that it matches in's layout. Basically we keep any
// transpositions as it were and that allows us either to skip finding the
// location of the output that matches the input or simply contiguous read or
// writes.
auto out_strides = in.strides();
for (auto ax : axes) {
for (auto& s : out_strides) {
if (s > in.strides(ax)) {
s /= in.shape(ax);
}
}
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
}
} // namespace mlx::core

View File

@ -55,87 +55,28 @@ struct RowReduceArgs {
non_row_reductions *= reduce_shape[i];
}
}
// Convert shape and strides as if in was contiguous
void convert_shapes_to_contiguous(
const array& in,
const std::vector<int>& axes) {
auto shape_vec = in.shape();
auto strides_vec = in.strides();
size_t s = 1;
for (int i = in.ndim() - 1; i >= 0; i--) {
strides_vec[i] = s;
s *= shape_vec[i];
}
std::tie(shape_vec, strides_vec) =
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / WARP_SIZE;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
n += WARP_SIZE) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
}
total_val = cg::reduce(warp, total_val, op);
if (warp.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
@ -153,59 +94,37 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const size_t start_row =
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
out += start_row;
int i = 0;
for (; i + block.size() * N <= size; i += block.size() * N) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(), in + k * size + i, vals[k]);
block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
if (size > i) {
if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + i,
in + k * size + final_offset,
vals[k],
size,
__cast<T, U>(init));
for (int j = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
if (warp.meta_group_size() > 1) {
__shared__ U shared_accumulators[32 * M];
if (warp.thread_rank() == 0) {
for (int i = 0; i < M; i++) {
shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < M; i++) {
accs[i] = shared_accumulators[warp.thread_rank() * M + i];
}
} else {
for (int i = 0; i < M; i++) {
accs[i] = init;
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
}
__shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
@ -226,7 +145,7 @@ template <
typename U,
typename Op,
int NDIM,
int BLOCK_DIM_X,
int BLOCK_DIM,
int N_READS = 4>
__global__ void row_reduce_looped(
T* in,
@ -237,27 +156,28 @@ __global__ void row_reduce_looped(
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
if (out_idx >= out_size) {
return;
}
size_t out_idx = grid.block_rank();
Op op;
U total_val = ReduceInit<Op, T>::value();
U total[1];
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM_X * N_READS,
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total_val = op(total_val, __cast<U, T>(vals[i]));
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
if (final_offset < args.row_size) {
@ -267,21 +187,20 @@ __global__ void row_reduce_looped(
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
__cast<T, U>(ReduceInit<Op, T>::value()));
__cast<T, U>(init));
for (int i = 0; i < N_READS; i++) {
total_val = op(total_val, __cast<U, T>(vals[i]));
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
total_val = BlockReduceT(temp).Reduce(total_val, op);
__shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[out_idx] = total_val;
out[out_idx] = total[0];
}
}
@ -296,23 +215,9 @@ void row_reduce_simple(
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Initialize out such that its strides match in's layout (except the fastest
// moving axis)
auto out_strides = in.strides();
for (auto& s : out_strides) {
s /= plan.shape.back();
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
@ -329,7 +234,7 @@ void row_reduce_simple(
using U = cu::ReduceResult<OP, T>::type;
// Calculate the grid and block dims
size_t reductions = plan.shape.back() / N_READS;
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@ -356,31 +261,13 @@ void row_reduce_looped(
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Initialize out such that it matches in's layout. Basically we keep any
// transpositions as it were and that allows us to skip finding the location
// of the output that matches the input.
auto out_strides = in.strides();
for (auto ax : axes) {
for (auto& s : out_strides) {
if (s > in.strides(ax)) {
s /= in.shape(ax);
}
}
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
@ -395,9 +282,9 @@ void row_reduce_looped(
using U = cu::ReduceResult<OP, T>::type;
// Calculate the grid and block dims
cu::RowReduceArgs args(in, plan, axes);
args.convert_shapes_to_contiguous(x, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = args.row_size / N_READS;
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
@ -426,62 +313,35 @@ void row_reduce(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
// Current row reduction options
//
// - row_reduce_simple
//
// That means that we are simply reducing across the fastest moving axis.
// We are reducing 1 or 2 rows per threadblock depending on the size of
// output.
//
// - row_reduce_looped
//
// It is a general row reduction. We are computing 1 output per
// threadblock. We read the fastest moving axis vectorized and loop over
// the rest of the axes.
//
// Notes: We opt to read as much in order as possible and leave
// transpositions as they are (contrary to our Metal backend).
// Simple row reduce means that we have 1 axis that we are reducing over and
// it has stride 1.
if (plan.shape.size() == 1) {
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
return;
}
// Fallback row reduce
row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
// Make the args struct to help route to the best kernel
cu::RowReduceArgs args(in, plan, axes);
// encoder.launch_kernel([&](cudaStream_t stream) {
// MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
// using InType = cuda_type_t<CTYPE>;
// MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
// using OutType = cu::ReduceResult<OP, InType>::type;
// MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
// constexpr size_t N_READS = 4;
// dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
// dim3 block_dims, num_blocks;
// auto kernel =
// cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
// if (args.row_size <= 64) {
// if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
// (args.non_row_reductions <= 8)) {
// block_dims.x = std::min(out_dims.x, 1024u);
// num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
// num_blocks.y = out_dims.y;
// } else {
// block_dims.x = WARP_SIZE;
// num_blocks.y = out_dims.x;
// num_blocks.z = out_dims.y;
// kernel =
// cu::row_reduce_small_warp<InType, OutType, OP, NDIM,
// N_READS>;
// }
// } else {
// size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
// num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
// MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
// num_blocks.y = out_dims.x;
// num_blocks.z = out_dims.y;
// block_dims.x = BLOCK_DIM_X;
// kernel = cu::row_reduce_looped<
// InType,
// OutType,
// OP,
// NDIM,
// BLOCK_DIM_X,
// N_READS>;
// });
// }
// kernel<<<num_blocks, block_dims, 0, stream>>>(
// in.data<InType>(), out.data<OutType>(), out.size(), args);
// });
// });
// });
// });
// Fallback row reduce
row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
}
} // namespace mlx::core

View File

@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::finite_min());
Limits<AccT>::min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: Limits<AccT>::finite_min();
: Limits<AccT>::min();
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {

View File

@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.flags());
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

View File

@ -413,7 +413,7 @@ class Module(dict):
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):

View File

@ -0,0 +1,17 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc*
cd wheelhouse
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .

View File

@ -13,7 +13,6 @@ cuda_skip = {
"TestLayers.test_upsample",
"TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing",
"TestOps.test_softmax",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",

View File

@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):

View File

@ -174,20 +174,26 @@ if __name__ == "__main__":
)
package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
install_requires = []
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
setup(
name="mlx",
name="mlx-cuda" if build_cuda else "mlx",
version=get_version(),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",
long_description=long_description,
long_description_content_type="text/markdown",
license="MIT",
url="https://github.com/ml-explore/mlx",
packages=packages,
package_dir=package_dir,
package_data=package_data,
include_package_data=True,
install_requires=install_requires,
extras_require={
"dev": [
"nanobind==2.4.0",