Compare commits


1 commit

Author: Angelos Katharopoulos
SHA1: 870208eff5
Message: Start sdpa vector
Date: 2025-06-16 17:38:39 -07:00
45 changed files with 601 additions and 1529 deletions


@@ -16,9 +16,6 @@ parameters:
 linux_release:
 type: boolean
 default: false
-cuda_release:
-type: boolean
-default: false
 jobs:
 build_documentation:

@@ -107,7 +104,7 @@ jobs:
 command: |
 echo "stubs"
 pip install typing_extensions
 python setup.py generate_stubs
 - run:
 name: Run Python tests
 command: |

@@ -165,7 +162,7 @@ jobs:
 command: |
 source env/bin/activate
 pip install typing_extensions
 python setup.py generate_stubs
 - run:
 name: Run Python tests
 command: |

@@ -226,6 +223,7 @@ jobs:
 command: |
 sudo apt-get update
 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
 python -m venv env
 source env/bin/activate
 CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \

@@ -285,7 +283,7 @@ jobs:
 command: |
 source env/bin/activate
 pip install typing_extensions
 python setup.py generate_stubs
 - run:
 name: Build Python package
 command: |

@@ -344,7 +342,7 @@ jobs:
 CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
 pip install . -v
 pip install typing_extensions
 python setup.py generate_stubs
 << parameters.extra_env >> \
 CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
 python -m build --wheel

@@ -358,48 +356,6 @@ jobs:
 - store_artifacts:
 path: wheelhouse/
-build_cuda_release:
-parameters:
-python_version:
-type: string
-default: "3.9"
-extra_env:
-type: string
-default: "DEV_RELEASE=1"
-machine:
-image: linux-cuda-12:default
-resource_class: gpu.nvidia.small.gen2
-steps:
-- checkout
-- run:
-name: Build wheel
-command: |
-sudo apt-get update
-sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-python -m venv env
-source env/bin/activate
-pip install auditwheel
-pip install patchelf
-pip install build
-pip install twine
-<< parameters.extra_env >> \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-pip install ".[dev]" -v
-python setup.py generate_stubs
-<< parameters.extra_env >> \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-python -m build --wheel
-bash python/scripts/repair_cuda.sh
-- run:
-name: Upload package
-command: |
-source env/bin/activate
-twine upload wheelhouse/*.whl
-- store_artifacts:
-path: wheelhouse/
 workflows:
 build_and_test:
 when:

@@ -669,14 +625,3 @@ workflows:
 parameters:
 python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 extra_env: ["PYPI_RELEASE=1"]
-cuda_test_release:
-when:
-and:
-- equal: [ main, << pipeline.git.branch >> ]
-- << pipeline.parameters.cuda_release >>
-jobs:
-- build_cuda_release:
-matrix:
-parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-extra_env: ["PYPI_RELEASE=1"]


@@ -5,7 +5,6 @@ import os
 import time
 import torch
-import torch.cuda
 import torch.mps

@@ -45,10 +44,8 @@ def bench(f, *args):
 def sync_if_needed(x):
-if x.device == torch.device("mps"):
+if x.device != torch.device("cpu"):
 torch.mps.synchronize()
-elif x.device == torch.device("cuda"):
-torch.cuda.synchronize()
 @torch.no_grad()

@@ -102,14 +99,6 @@ def reduction(op, axis, x):
 sync_if_needed(x)
-@torch.no_grad()
-def sum_and_add(axis, x, y):
-z = x.sum(axis=axis, keepdims=True)
-for i in range(50):
-z = (z + y).sum(axis=axis, keepdims=True)
-sync_if_needed(x)
 @torch.no_grad()
 def softmax(axis, x):
 ys = []

@@ -351,11 +340,7 @@ if __name__ == "__main__":
 args.axis.pop(0)
 torch.set_num_threads(1)
-device = "mps"
-if torch.cuda.is_available():
-device = "cuda"
-if args.cpu:
-device = "cpu"
+device = "cpu" if args.cpu else "mps"
 types = args.dtype
 if not types:

@@ -475,8 +460,5 @@ if __name__ == "__main__":
 elif args.benchmark == "selu":
 print(bench(selu, x))
-elif args.benchmark == "sum_and_add":
-print(bench(sum_and_add, axis, *xs))
 else:
 raise ValueError(f"Unknown benchmark `{args.benchmark}`.")


@@ -30,16 +30,6 @@ MLX is also available on conda-forge. To install MLX with conda do:
 conda install conda-forge::mlx
-CUDA
-^^^^
-MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
-and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
-.. code-block:: shell
-pip install mlx-cuda
 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -75,8 +65,6 @@ Build Requirements
 Python API
 ^^^^^^^^^^
-.. _python install:
 To build and install the MLX python library from source, first, clone MLX from
 `its GitHub repo <https://github.com/ml-explore/mlx>`_:

@@ -119,8 +107,6 @@ IDE:
 C++ API
 ^^^^^^^
-.. _cpp install:
 Currently, MLX must be built and installed from source.
 Similarly to the python library, to build and install the MLX C++ library start

@@ -199,7 +185,6 @@ should point to the path to the built metal library.
 xcrun -sdk macosx --show-sdk-version
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~

@@ -228,50 +213,6 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
 Metal kernel cache persists across reboots.
-Linux
-^^^^^
-To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
-For example on Ubuntu, run the following:
-.. code-block:: shell
-apt-get update -y
-apt-get install libblas-dev liblapack-dev liblapacke-dev -y
-From here follow the instructions to install either the :ref:`Python <python
-install>` or :ref:`C++ <cpp install>` APIs.
-CUDA
-^^^^
-To build from source on Linux with CUDA, install the BLAS and LAPACK headers
-and the CUDA toolkit. For example on Ubuntu, run the following:
-.. code-block:: shell
-wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
-dpkg -i cuda-keyring_1.1-1_all.deb
-apt-get update -y
-apt-get -y install cuda-toolkit-12-9
-apt-get install libblas-dev liblapack-dev liblapacke-dev -y
-When building either the Python or C++ APIs make sure to pass the cmake flag
-``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
-.. code-block:: shell
-CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
-To build the C++ package run:
-.. code-block:: shell
-mkdir -p build && cd build
-cmake .. -DMLX_BUILD_CUDA=ON && make -j
 Troubleshooting
 ^^^^^^^^^^^^^^^


@@ -14,8 +14,6 @@ void print_constant(std::ostream& os, const array& x) {
 return print_float_constant<float16_t>(os, x);
 case bfloat16:
 return print_float_constant<bfloat16_t>(os, x);
-case float64:
-return print_float_constant<double>(os, x);
 case complex64:
 return print_complex_constant<complex64_t>(os, x);
 case int8:

@@ -52,8 +50,6 @@ std::string get_type_string(Dtype d) {
 return "float16_t";
 case bfloat16:
 return "bfloat16_t";
-case float64:
-return "double";
 case complex64:
 return "complex64_t";
 case bool_:


@@ -18,12 +18,8 @@ std::string get_type_string(Dtype d);
 template <typename T>
 void print_float_constant(std::ostream& os, const array& x) {
 auto old_precision = os.precision();
-if constexpr (std::is_same_v<T, double>) {
-os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
-} else {
-os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
-}
-os << x.item<T>() << std::setprecision(old_precision);
+os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+<< x.item<T>() << std::setprecision(old_precision);
 }
 template <typename T>


@@ -5,9 +5,11 @@
 namespace mlx::core {
 std::pair<Shape, Strides> shapes_without_reduction_axes(
-Shape shape,
-Strides strides,
+const array& x,
 const std::vector<int>& axes) {
+auto shape = x.shape();
+auto strides = x.strides();
 for (int i = axes.size() - 1; i >= 0; i--) {
 int a = axes[i];
 shape.erase(shape.begin() + a);

@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
 return std::make_pair(shape, strides);
 }
-std::pair<Shape, Strides> shapes_without_reduction_axes(
-const array& x,
-const std::vector<int>& axes) {
-auto shape = x.shape();
-auto strides = x.strides();
-return shapes_without_reduction_axes(
-std::move(shape), std::move(strides), axes);
-}
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
 // The data is all there and we are reducing over everything
 if (x.size() == x.data_size() && axes.size() == x.ndim() &&
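
Both signatures in this hunk boil down to the same operation: erase the reduced axes from the shape/strides pair so only the surviving dimensions remain. A standalone sketch of that core step (simplified, not the MLX implementation; assumes `axes` is sorted in ascending order):

#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides> drop_reduction_axes(
    Shape shape,
    Strides strides,
    const std::vector<int>& axes) {
  // Erase from the back so earlier indices remain valid.
  for (int i = static_cast<int>(axes.size()) - 1; i >= 0; --i) {
    shape.erase(shape.begin() + axes[i]);
    strides.erase(strides.begin() + axes[i]);
  }
  return {shape, strides};
}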


@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
 const array& x,
 const std::vector<int>& axes);
-std::pair<Shape, Strides> shapes_without_reduction_axes(
-Shape shape,
-Strides strides,
-const std::vector<int>& axes);
 } // namespace mlx::core


@@ -199,15 +199,12 @@ Dims get_2d_grid_dims_common(
 }
 }
 }
-if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
 throw std::runtime_error("Unable to safely factor shape.");
 }
 if (grid_y > grid_x) {
 std::swap(grid_x, grid_y);
 }
-if (divisor > 1) {
-grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
-}
 return std::make_tuple(
 static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
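
The two sides differ in how they treat `divisor`: one throws when a divisor larger than one is requested, the other rounds `grid_x` up so every block covers a whole number of divisor-sized chunks. A sketch of that round-up idiom (hypothetical helper, not an MLX API):

#include <cstddef>

// Round grid_x up to the next multiple of divisor (no-op for divisor <= 1).
// Example: round_up_to_multiple(100, 32) == 128.
inline size_t round_up_to_multiple(size_t grid_x, size_t divisor) {
  if (divisor <= 1) {
    return grid_x;
  }
  return ((grid_x + divisor - 1) / divisor) * divisor;
}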


@@ -8,7 +8,6 @@ target_sources(
 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
-${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu

@@ -29,12 +28,12 @@ target_sources(
 ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
-${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
-${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
+${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
-${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu


@@ -3,7 +3,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/utils.h"
 #include <cuda_runtime.h>
 #include <fmt/format.h>

@@ -15,11 +14,9 @@ namespace mlx::core {
 namespace cu {
-constexpr int page_size = 16384;
 CudaAllocator::CudaAllocator()
 : buffer_cache_(
-page_size,
+getpagesize(),
 [](CudaBuffer* buf) { return buf->size; },
 [this](CudaBuffer* buf) {
 cuda_free(buf->data);

@@ -34,14 +31,7 @@ CudaAllocator::CudaAllocator()
 Buffer CudaAllocator::malloc(size_t size) {
 // Find available buffer from cache.
-auto orig_size = size;
 std::unique_lock lock(mutex_);
-if (size < page_size) {
-size = next_power_of_2(size);
-} else {
-size = page_size * ((size + page_size - 1) / page_size);
-}
 CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
 if (!buf) {
 // If we have a lot of memory pressure or are over the maximum cache size,

@@ -116,6 +106,7 @@ void CudaAllocator::cuda_free(void* buf) {
 return;
 }
 }
 cudaFree(buf);
 }
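
The lines stripped from `malloc` implemented size bucketing on top of the buffer cache: requests below the 16 KB page constant round up to the next power of two, larger ones to a whole number of pages, so freed buffers can be reused for similar-sized allocations. A self-contained sketch of that policy (assuming the 16384-byte constant from the hunk; `next_power_of_2` here is a stand-in for the helper provided by the removed `mlx/utils.h` include):

#include <cstddef>

constexpr size_t page_size = 16384;

// Stand-in for the next_power_of_2 helper; returns the smallest power of two >= n.
inline size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Bucket a requested size so cached buffers are reusable across similar requests.
inline size_t bucket_allocation_size(size_t size) {
  if (size < page_size) {
    return next_power_of_2(size);
  }
  return page_size * ((size + page_size - 1) / page_size);
}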


@@ -125,12 +125,13 @@ constexpr bool supports_binary_op() {
 template <typename Op>
 void binary_op_gpu_inplace(
 const std::vector<array>& inputs,
-array& out,
+std::vector<array>& outputs,
 std::string_view op,
 const Stream& s) {
 assert(inputs.size() > 1);
 const auto& a = inputs[0];
 const auto& b = inputs[1];
+auto& out = outputs[0];
 if (out.size() == 0) {
 return;
 }

@@ -145,6 +146,7 @@ void binary_op_gpu_inplace(
 if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
 using InType = cuda_type_t<CTYPE_IN>;
 using OutType = cuda_type_t<CTYPE_OUT>;
 auto bopt = get_binary_op_type(a, b);
 if (bopt == BinaryOpType::General) {
 auto [shape, strides] = collapse_contiguous_dims(a, b, out);

@@ -217,6 +219,20 @@ void binary_op_gpu_inplace(
 });
 }
+template <typename Op>
+void binary_op_gpu(
+const std::vector<array>& inputs,
+std::vector<array>& outputs,
+std::string_view op,
+const Stream& s) {
+auto& a = inputs[0];
+auto& b = inputs[1];
+auto bopt = get_binary_op_type(a, b);
+set_binary_op_output_data(a, b, outputs[0], bopt);
+set_binary_op_output_data(a, b, outputs[1], bopt);
+binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}
 template <typename Op>
 void binary_op_gpu(
 const std::vector<array>& inputs,

@@ -227,7 +243,8 @@ void binary_op_gpu(
 auto& b = inputs[1];
 auto bopt = get_binary_op_type(a, b);
 set_binary_op_output_data(a, b, out, bopt);
-binary_op_gpu_inplace<Op>(inputs, out, op, s);
+std::vector<array> outputs{out};
+binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
 }
 #define BINARY_GPU(func) \

@@ -237,6 +254,14 @@ void binary_op_gpu(
 binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
 }
+#define BINARY_GPU_MULTI(func) \
+void func::eval_gpu( \
+const std::vector<array>& inputs, std::vector<array>& outputs) { \
+nvtx3::scoped_range r(#func "::eval_gpu"); \
+auto& s = outputs[0].primitive().stream(); \
+binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
+}
 BINARY_GPU(Add)
 BINARY_GPU(ArcTan2)
 BINARY_GPU(Divide)
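
The `std::vector<array>& outputs` overload and the `BINARY_GPU_MULTI` macro serve primitives that produce two results from a single binary operation, DivMod being the canonical case (its kernels appear in the binary_two file elsewhere in this compare). A host-side sketch of the two-output functor shape such a path dispatches to (illustrative only, not the CUDA device code):

#include <array>
#include <cmath>
#include <type_traits>

struct DivModSketch {
  template <typename T>
  std::array<T, 2> operator()(T x, T y) const {
    if constexpr (std::is_integral_v<T>) {
      // Quotient and remainder for integral types.
      return {static_cast<T>(x / y), static_cast<T>(x % y)};
    } else {
      // Truncated quotient and floating-point remainder otherwise.
      return {std::trunc(x / y), std::fmod(x, y)};
    }
  }
};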


@@ -1,248 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core


@@ -24,6 +24,7 @@ void copy_gpu_inplace(
 auto& encoder = cu::get_command_encoder(s);
 encoder.set_input_array(in);
 encoder.set_output_array(out);
 if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
 copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
 return;


@@ -63,30 +63,25 @@ void copy_general(
 MLX_SWITCH_BOOL(large, LARGE, {
 using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
 int ndim = shape.size();
-size_t data_size = 1;
-for (auto& s : shape)
-data_size *= s;
 if (ndim <= 3) {
 MLX_SWITCH_1_2_3(ndim, NDIM, {
 auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-auto [num_blocks, block_dims] =
-get_launch_args(kernel, data_size, shape, out.strides(), large);
+auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
 kernel<<<num_blocks, block_dims, 0, stream>>>(
 in_ptr,
 out_ptr,
-data_size,
+out.size(),
 const_param<NDIM>(shape),
 const_param<NDIM>(strides_in),
 const_param<NDIM>(strides_out));
 });
 } else { // ndim >= 4
 auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-auto [num_blocks, block_dims] =
-get_launch_args(kernel, data_size, shape, out.strides(), large);
+auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
 kernel<<<num_blocks, block_dims, 0, stream>>>(
 in_ptr,
 out_ptr,
-data_size,
+out.size(),
 const_param(shape),
 const_param(strides_in),
 const_param(strides_out),


@@ -6,7 +6,6 @@
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
-#include <future>
 namespace mlx::core {

@@ -108,16 +107,6 @@ void CommandEncoder::commit() {
 worker_.commit(stream_.last_cuda_stream());
 }
-void CommandEncoder::synchronize() {
-stream().synchronize();
-auto p = std::make_shared<std::promise<void>>();
-std::future<void> f = p->get_future();
-add_completed_handler([p = std::move(p)]() { p->set_value(); });
-worker_.end_batch();
-commit();
-f.wait();
-}
 Device& device(mlx::core::Device device) {
 static std::unordered_map<int, Device> devices;
 auto it = devices.find(device.index);


@@ -123,9 +123,6 @@ class CommandEncoder {
 return has_gpu_work_;
 }
-// Wait until kernels and completion handlers are finished
-void synchronize();
 private:
 Device& device_;
 DeviceStream& stream_;


@@ -22,7 +22,7 @@ struct FloorDivide {
 if constexpr (cuda::std::is_integral_v<T>) {
 return x / y;
 } else {
-return truncf(x / y);
+return trunc(x / y);
 }
 }
 };

@@ -132,7 +132,7 @@ struct LogAddExp {
 cuda::std::numeric_limits<float>::quiet_NaN(),
 cuda::std::numeric_limits<float>::quiet_NaN()};
 }
-float inf = cuda::std::numeric_limits<float>::infinity();
+constexpr float inf = cuda::std::numeric_limits<float>::infinity();
 auto maxval = x > y ? x : y;
 auto minval = x < y ? x : y;
 if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
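
The `LogAddExp` overload above guards the complex case against NaNs and infinite endpoints before combining the values. For reference, a sketch of the usual numerically stable real-valued form, max + log1p(exp(min - max)), which avoids overflow when the inputs are large (standalone illustration, not the kernel code):

#include <algorithm>
#include <cmath>
#include <limits>

inline float log_add_exp(float x, float y) {
  if (std::isnan(x) || std::isnan(y)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  float maxval = std::max(x, y);
  float minval = std::min(x, y);
  if (std::isinf(maxval)) {
    // +inf dominates; two -inf inputs stay -inf.
    return maxval;
  }
  return maxval + std::log1p(std::exp(minval - maxval));
}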


@@ -5,7 +5,7 @@
 #pragma once
 // The maximum dimensions of shape/strides passed as kernel parameters.
-#define MAX_NDIM 10
+#define MAX_NDIM 8
 // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
 // warpSize variable exists, using it would prevent compile-time optimizations.


@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
 for (int i = NDIM - 1; i >= 0; --i) {
 int dim_idx = elem % shape[i];
-a_loc += dim_idx * IdxT(a_strides[i]);
-b_loc += dim_idx * IdxT(b_strides[i]);
+a_loc += dim_idx * a_strides[i];
+b_loc += dim_idx * b_strides[i];
 elem /= shape[i];
 }
 return cuda::std::make_tuple(a_loc, b_loc);

@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
 for (int i = NDIM - 1; i >= 0; --i) {
 int dim_idx = elem % shape[i];
-a_loc += dim_idx * IdxT(a_strides[i]);
-b_loc += dim_idx * IdxT(b_strides[i]);
-c_loc += dim_idx * IdxT(c_strides[i]);
+a_loc += dim_idx * a_strides[i];
+b_loc += dim_idx * b_strides[i];
+c_loc += dim_idx * c_strides[i];
 elem /= shape[i];
 }
 return cuda::std::make_tuple(a_loc, b_loc, c_loc);

@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
 IdxT b_loc = 0;
 for (int i = ndim - 1; i >= 0; --i) {
 int dim_idx = elem % shape[i];
-a_loc += dim_idx * IdxT(a_strides[i]);
-b_loc += dim_idx * IdxT(b_strides[i]);
+a_loc += dim_idx * a_strides[i];
+b_loc += dim_idx * b_strides[i];
 elem /= shape[i];
 }
 return cuda::std::make_tuple(a_loc, b_loc);

@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
 IdxT c_loc = 0;
 for (int i = ndim - 1; i >= 0; --i) {
 int dim_idx = elem % shape[i];
-a_loc += dim_idx * IdxT(a_strides[i]);
-b_loc += dim_idx * IdxT(b_strides[i]);
-c_loc += dim_idx * IdxT(c_strides[i]);
+a_loc += dim_idx * a_strides[i];
+b_loc += dim_idx * b_strides[i];
+c_loc += dim_idx * c_strides[i];
 elem /= shape[i];
 }
 return cuda::std::make_tuple(a_loc, b_loc, c_loc);
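
The only difference between the two sides of this hunk is the `IdxT(...)` cast on each stride: promoting the per-dimension product to the index type before accumulating keeps the arithmetic in 64 bits when a 64-bit index type is selected for large arrays. A host-side sketch of the same flat-index-to-strided-offset mapping (simplified to a single input):

#include <cstdint>
#include <vector>

template <typename IdxT>
IdxT elem_to_loc_sketch(
    IdxT elem,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  IdxT loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    int dim_idx = static_cast<int>(elem % shape[i]);
    // Promote the stride to IdxT before the multiply, as the casts above do.
    loc += dim_idx * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}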


@@ -62,7 +62,7 @@ void finalize(Stream s) {
 void synchronize(Stream s) {
 nvtx3::scoped_range r("gpu::synchronize");
-cu::get_command_encoder(s).synchronize();
+cu::get_stream(s).synchronize();
 }
 } // namespace mlx::core::gpu


@@ -37,46 +37,36 @@ void check_cu_error(const char* name, CUresult err) {
} }
// Return the location of the CUDA toolkit. // Return the location of the CUDA toolkit.
const std::string& cuda_home() { const char* cuda_home() {
static std::string home = []() -> std::string { const char* home = std::getenv("CUDA_HOME");
const char* home = std::getenv("CUDA_HOME"); if (home) {
if (home) { return home;
return home; }
} home = std::getenv("CUDA_PATH");
home = std::getenv("CUDA_PATH"); if (home) {
if (home) { return home;
return home; }
}
#if defined(__linux__) #if defined(__linux__)
home = "/usr/local/cuda"; home = "/usr/local/cuda";
if (std::filesystem::exists(home)) { if (std::filesystem::exists(home)) {
return home; return home;
} }
#endif #endif
throw std::runtime_error( throw std::runtime_error(
"Environment variable CUDA_HOME or CUDA_PATH is not set."); "Environment variable CUDA_HOME or CUDA_PATH is not set.");
}();
return home;
} }
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() { bool get_ptx_cache_dir(std::filesystem::path* result) {
static std::filesystem::path cache = []() -> std::filesystem::path { auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
std::filesystem::path cache; if (!std::filesystem::is_directory(path)) {
if (auto c = std::getenv("MLX_PTX_CACHE"); c) { std::error_code error;
cache = c; if (!std::filesystem::create_directories(path, error)) {
} else { return false;
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
} }
if (!std::filesystem::exists(cache)) { }
std::error_code error; *result = path;
if (!std::filesystem::create_directories(cache, error)) { return true;
return std::filesystem::path();
}
}
return cache;
}();
return cache;
} }
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
@@ -85,10 +75,6 @@ bool read_cached_ptx(
const std::string& module_name, const std::string& module_name,
std::vector<char>* ptx, std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) { std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx"); auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error; std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error); auto ptx_size = std::filesystem::file_size(ptx_path, error);
@@ -119,10 +105,6 @@ void write_cached_ptx(
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) { const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return;
}
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
if (!ptx.empty()) { if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size()); ptx_file.write(&ptx.front(), ptx.size());
@@ -202,9 +184,11 @@ JitModule::JitModule(
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder) { const KernelBuilder& builder) {
// Check cache. // Check cache.
std::filesystem::path cache_dir;
std::vector<char> ptx; std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels; std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { if (!get_ptx_cache_dir(&cache_dir) ||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
// Create program. // Create program.
auto [source_code, kernel_names] = builder(); auto [source_code, kernel_names] = builder();
nvrtcProgram prog; nvrtcProgram prog;
@@ -262,7 +246,7 @@ JitModule::JitModule(
} else { } else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
} }
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
} }
// Load module. // Load module.


@@ -162,15 +162,11 @@ class MatMul {
} }
} }
void* workspace_ptr = nullptr; array workspace(
if (heuristic_.workspaceSize > 0) { allocator::malloc(heuristic_.workspaceSize),
array workspace( {static_cast<int>(heuristic_.workspaceSize)},
allocator::malloc(heuristic_.workspaceSize), int8);
{static_cast<int>(heuristic_.workspaceSize)}, encoder.add_temporary(workspace);
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
@@ -187,8 +183,8 @@ class MatMul {
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace_ptr, workspace.data<void>(),
heuristic_.workspaceSize, workspace.nbytes(),
stream)); stream));
}); });
} }
@@ -362,18 +358,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -457,28 +444,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,


@@ -43,17 +43,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 });
 }
-bool fast::ScaledDotProductAttention::use_fallback(
-const array& q,
-const array& k,
-const array& v,
-bool has_mask,
-bool has_arr_mask,
-bool do_causal,
-Stream s) {
-return true;
-}
 #define NO_GPU_MULTI(func) \
 void func::eval_gpu( \
 const std::vector<array>& inputs, std::vector<array>& outputs) { \

@@ -71,8 +60,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
 throw std::runtime_error(#func " has no CUDA implementation."); \
 }
+NO_GPU(ArgPartition)
 NO_GPU(BlockMaskedMM)
 NO_GPU(Convolution)
+NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)

@@ -81,6 +72,7 @@ NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
 NO_GPU_MULTI(LUF)
+NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)

@@ -91,7 +83,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 namespace fast {
-NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast


@@ -21,11 +21,28 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 assert(!axes_.empty());
 assert(out.size() != in.size());
+out.set_data(allocator::malloc(out.nbytes()));
 auto& s = stream();
 auto& encoder = cu::get_command_encoder(s);
+encoder.set_input_array(in);
+encoder.set_output_array(out);
+// Fill out with init value.
 if (in.size() == 0) {
-init_reduce(encoder, in, out, reduce_type_);
+encoder.launch_kernel([&](cudaStream_t stream) {
+MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
+using InType = cuda_type_t<CTYPE>;
+using OutType = cu::ReduceResult<OP, InType>::type;
+thrust::fill_n(
+cu::thrust_policy(stream),
+thrust::device_pointer_cast(out.data<OutType>()),
+out.data_size(),
+cu::ReduceInit<OP, InType>::value());
+});
+});
+});
 return;
 }

@@ -34,19 +51,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 // If it is a general reduce then copy the input to a contiguous array and
 // recompute the plan.
-//
-// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
-// like we do in Metal. When it comes to broadcasted reduction axes
-// some can be ignored eg for min/max.
-bool broadcasted = false;
-for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
-if (j < axes_.size() && axes_[j] == i) {
-j++;
-} else {
-broadcasted = in.strides(i) == 0;
-}
-}
-if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
+if (plan.type == GeneralReduce) {
 array in_copy(in.shape(), in.dtype(), nullptr, {});
 copy_gpu(in, in_copy, CopyType::General, s);
 encoder.add_temporary(in_copy);

@@ -54,8 +59,9 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 plan = get_reduction_plan(in, axes_);
 }
-if (plan.type == ContiguousAllReduce) {
-all_reduce(encoder, in, out, reduce_type_);
+if ((plan.type == ContiguousAllReduce) ||
+(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
+segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
 return;
 }
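
The broadcast check removed on one side of this hunk detects a non-reduced input dimension with stride 0 (produced by broadcasting) and, together with the contiguity flag, routes such inputs through a contiguous copy before the reduction plan is recomputed. A standalone sketch of that check (simplified; assumes the reduction axes are sorted ascending):

#include <cstdint>
#include <vector>

bool has_broadcast_on_kept_axes(
    const std::vector<int64_t>& strides,
    const std::vector<int>& sorted_axes) {
  for (int i = 0, j = 0; i < static_cast<int>(strides.size()); ++i) {
    if (j < static_cast<int>(sorted_axes.size()) && sorted_axes[j] == i) {
      ++j;  // This dimension is being reduced; ignore it.
    } else if (strides[i] == 0) {
      return true;  // Broadcasted dimension that survives the reduction.
    }
  }
  return false;
}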


@@ -1,150 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[N];
U accs[M];
accs[0] = init;
size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);
size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}
if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
}
__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}
} // namespace cu
void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes()));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int reductions_per_step = threads * N;
size_t steps_needed =
(size + reductions_per_step - 1) / reductions_per_step;
int blocks;
if (steps_needed < 32) {
blocks = 1;
} else if (steps_needed < 128) {
blocks = 32;
} else if (steps_needed < 512) {
blocks = 128;
} else if (steps_needed < 1024) {
blocks = 512;
} else {
blocks = 1024;
}
size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
size_t block_step = steps_per_block * reductions_per_step;
return std::make_tuple(blocks, threads, block_step);
};
int blocks, threads;
size_t block_step;
size_t insize = in.size();
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
insize);
});
});
});
// Set the input for the next step and recalculate the blocks
indata = intermediate.data<void>();
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(intermediate);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata), out.data<U>(), block_step, insize);
});
});
});
}
} // namespace mlx::core


@@ -1,7 +1,5 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -38,36 +36,19 @@ struct ColReduceArgs {
const array& in, const array& in,
const ReductionPlan& plan, const ReductionPlan& plan,
const std::vector<int>& axes) { const std::vector<int>& axes) {
using ShapeVector = decltype(plan.shape);
using StridesVector = decltype(plan.strides);
ShapeVector shape_vec;
StridesVector strides_vec;
assert(!plan.shape.empty()); assert(!plan.shape.empty());
reduction_size = plan.shape.back(); reduction_size = plan.shape.back();
reduction_stride = plan.strides.back(); reduction_stride = plan.strides.back();
int64_t stride_back = 1; int64_t stride_back = 1;
std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) { while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back(); stride_back *= shape_vec.back();
shape_vec.pop_back(); shape_vec.pop_back();
strides_vec.pop_back(); strides_vec.pop_back();
} }
std::vector<int> indices(shape_vec.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
ShapeVector sorted_shape;
StridesVector sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
}
std::tie(shape_vec, strides_vec) = std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(sorted_shape, sorted_strides); collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec); shape = const_param(shape_vec);
strides = const_param(strides_vec); strides = const_param(strides_vec);
ndim = shape_vec.size(); ndim = shape_vec.size();
@@ -83,6 +64,86 @@ struct ColReduceArgs {
} }
}; };
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template < template <
typename T, typename T,
typename U, typename U,
@@ -91,94 +152,67 @@ template <
int BM, int BM,
int BN, int BN,
int N_READS = 4> int N_READS = 4>
__global__ void __global__ void col_reduce_looped(
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int threads_per_row = BN / N_READS; constexpr int n_warps = BN / N_READS;
// Compute the indices for the tile int out_idx = grid.block_rank() / grid.dim_blocks().x;
size_t tile_idx = grid.block_rank(); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
short thread_y = block.thread_rank() / threads_per_row;
// Move the input pointer
in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
tile_x * BN;
// Initialize the running totals
Op op; Op op;
U totals[N_READS]; U totals[N_READS];
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size; for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
if (tile_x * BN + BN <= args.reduction_stride) { U vals[N_READS];
if (args.reduction_stride % N_READS == 0) { cub::LoadDirectBlocked(
for (size_t r = thread_y; r < total; r += BM) { column,
T vals[N_READS]; make_cast_iterator<U>(in + loop.location() + in_offset),
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); vals,
for (int i = 0; i < N_READS; i++) { args.reduction_stride - in_offset,
totals[i] = op(totals[i], __cast<U, T>(vals[i])); ReduceInit<Op, T>::value());
} for (int i = 0; i < N_READS; i++) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); totals[i] = op(vals[i], totals[i]);
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
in + loop.location(),
vals,
args.reduction_stride - tile_x * BN,
__cast<T, U>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
// Do warp reduce for each output. // Do warp reduce for each output.
constexpr int n_outputs = BN / threads_per_row; constexpr int n_outputs = BN / n_warps;
static_assert(BM == 32 && n_outputs == N_READS); static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN]; __shared__ U shared_vals[BM * BN];
short s_idx = thread_y * BN + thread_x * N_READS; size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
shared_vals[s_idx + i] = totals[i]; shared_vals[col + i] = totals[i];
} }
block.sync(); block.sync();
s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); totals[i] = cg::reduce(warp, shared_vals[col + i], op);
} }
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN, out + out_idx * args.reduction_stride + out_offset,
totals, totals,
args.reduction_stride - tile_x * BN); args.reduction_stride - out_offset);
} }
} }
@@ -186,55 +220,14 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args, const cu::ColReduceArgs& args) {
int bn) { auto out_shape = out.shape();
int gx, gy = 1; auto out_strides = out.strides();
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
size_t n_outer_blocks = out.size() / args.reduction_stride; out_shape.pop_back();
size_t n_blocks = n_outer_blocks * n_inner_blocks; out_strides.pop_back();
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
} }
gx = cuda::ceil_div(n_blocks, gy); return get_2d_grid_dims(out_shape, out_strides);
return dim3(gx, gy, 1);
}
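
As a rough worked example of the variant of output_grid_for_col_reduce that takes a bn argument: with reduction_stride = 65536 and bn = 32 there are 2048 inner blocks per output row, so 10^7 output rows would need about 2.05e10 blocks in total; gy then doubles up to 16 (the first power of two for which n_blocks / gy drops below INT32_MAX) and gx = ceil_div(2.05e10, 16) ≈ 1.28e9, which fits in a 32-bit grid dimension. (The numbers are illustrative, not taken from this diff.)
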
void col_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
});
});
});
});
} }
void col_reduce( void col_reduce(
@@ -244,23 +237,42 @@ void col_reduce(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
// Current col reduce options
//
// - col_reduce_looped
//
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
// Moreover we need different kernels for short rows and tuning
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes); cu::ColReduceArgs args(in, plan, axes);
// Fallback col reduce encoder.launch_kernel([&](cudaStream_t stream) {
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
} }
} // namespace mlx::core } // namespace mlx::core
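
To make the tiling described in the comments above concrete (one threadblock per BN-wide sub-row of the fast-moving axis, with BM threads stepping over the reduction dimension), here is a minimal standalone sketch of the same idea for a plain row-major [rows x cols] sum. The kernel name col_sum_tile, the fixed 32x32 tile and the float-only types are illustrative assumptions, not MLX's actual kernels.

// Minimal sketch: sum-reduce over the leading axis of a row-major
// [rows x cols] matrix. Each block owns a BN-wide tile of columns, BM
// threads per column step through the rows, and the BM partial sums per
// column are combined in shared memory.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int BM = 32; // rows handled per loop iteration
constexpr int BN = 32; // columns handled per block

__global__ void col_sum_tile(const float* in, float* out, int rows, int cols) {
  int col = blockIdx.x * BN + threadIdx.x; // column this thread owns
  float total = 0.0f;
  if (col < cols) {
    for (int r = threadIdx.y; r < rows; r += BM) { // strided loop over rows
      total += in[static_cast<size_t>(r) * cols + col];
    }
  }
  __shared__ float partials[BM][BN];
  partials[threadIdx.y][threadIdx.x] = total;
  __syncthreads();
  // Tree-reduce the BM partial sums for each column.
  for (int offset = BM / 2; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      partials[threadIdx.y][threadIdx.x] +=
          partials[threadIdx.y + offset][threadIdx.x];
    }
    __syncthreads();
  }
  if (threadIdx.y == 0 && col < cols) {
    out[col] = partials[0][threadIdx.x];
  }
}

int main() {
  const int rows = 1000, cols = 70;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * rows * cols);
  cudaMallocManaged(&out, sizeof(float) * cols);
  for (int i = 0; i < rows * cols; ++i) in[i] = 1.0f;
  dim3 block(BN, BM);
  dim3 grid((cols + BN - 1) / BN);
  col_sum_tile<<<grid, block>>>(in, out, rows, cols);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expected %d)\n", out[0], rows);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Keeping consecutive threads on consecutive columns is what keeps the loads coalesced even though the reduction runs down the slow-moving axis, which is the point the comments above are making.
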
View File
@@ -1,50 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename U, typename Op>
__global__ void init_reduce(U* out, size_t size) {
auto index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = ReduceInit<Op, T>::value();
}
}
} // namespace cu
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::init_reduce<T, U, OP>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
});
});
}
} // namespace mlx::core
View File
@@ -47,11 +47,13 @@ namespace mlx::core {
throw std::invalid_argument("Unknown reduce type."); \ throw std::invalid_argument("Unknown reduce type."); \
} }
void all_reduce( void segmented_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType reduce_type); Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void row_reduce( void row_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -69,10 +71,4 @@ void col_reduce(
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan); const ReductionPlan& plan);
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type);
} // namespace mlx::core } // namespace mlx::core
View File
@@ -3,89 +3,48 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
// Reduce ops. // Reduce ops.
struct And { struct And {
__device__ __forceinline__ bool operator()(bool a, bool b) { __device__ bool operator()(bool a, bool b) {
return a && b; return a && b;
} }
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, And>(x, y);
}
}; };
struct Or { struct Or {
__device__ __forceinline__ bool operator()(bool a, bool b) { __device__ bool operator()(bool a, bool b) {
return a || b; return a || b;
} }
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, Or>(x, y);
}
}; };
struct Sum { struct Sum {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ T operator()(T a, T b) {
return a + b; return a + b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Sum>(x, y);
}
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomicAdd(x, y);
}
__device__ void atomic_update(int* x, int y) {
atomicAdd(x, y);
}
__device__ void atomic_update(float* x, float y) {
atomicAdd(x, y);
}
}; };
struct Prod { struct Prod {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ T operator()(T a, T b) {
return a * b; return a * b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Prod>(x, y);
}
}; };
struct Min { struct Min {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ T operator()(T a, T b) {
return a < b ? a : b; return a < b ? a : b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Min>(x, y);
}
}; };
struct Max { struct Max {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ T operator()(T a, T b) {
return a > b ? a : b; return a > b ? a : b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Max>(x, y);
}
}; };
// Traits to get the result type of reduce op. // Traits to get the result type of reduce op.
@@ -161,7 +120,7 @@ template <typename T>
struct ReduceInit<Prod, T> { struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() { static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 0}; return T{1, 1};
} else { } else {
return typename ReduceResult<Prod, T>::type{1}; return typename ReduceResult<Prod, T>::type{1};
} }
View File
@@ -1,158 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <numeric>
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <size_t N>
struct uint_by_size;
template <>
struct uint_by_size<2> {
using type = uint16_t;
};
template <>
struct uint_by_size<4> {
using type = uint32_t;
};
template <>
struct uint_by_size<8> {
using type = unsigned long long int;
};
template <typename T, typename Op>
__device__ void atomic_reduce(T* x, T y) {
if constexpr (sizeof(T) == 1) {
using U = uint16_t;
U* x_int = (U*)((char*)x - ((size_t)x % 2));
int shift = ((char*)x - (char*)x_int) * 8;
int mask = 0xff << shift;
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
new_val = (old_val & ~mask) | (result << shift);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
} else {
using U = typename uint_by_size<sizeof(T)>::type;
U* x_int = (U*)(x);
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(*((T*)&old_val), y);
new_val = *((U*)&result);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
}
}
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
// First reduce in the current warp
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
// Reduce across warps
if (warp.meta_group_size() > 1) {
if (warp.thread_rank() == 0) {
for (int i = 0; i < N; i++) {
smem[warp.meta_group_rank() * N + i] = vals[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < N; i++) {
vals[i] = smem[warp.thread_rank() * N + i];
}
} else {
for (int i = 0; i < N; i++) {
vals[i] = init;
}
}
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
}
}
} // namespace cu
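
As a hedged aside on the atomic_reduce helper above: for word-sized types it is the classic compare-and-swap retry loop, re-applying the op until no other thread has raced in between the read and the CAS. A self-contained sketch of that pattern for a float max follows; MaxOp, atomic_max_float and reduce_max are illustrative names, not part of this codebase.

// Sketch of a CAS retry loop for an atomic float max, the same pattern
// atomic_reduce uses for arbitrary word-sized types and ops.
#include <cstdio>
#include <cuda_runtime.h>

struct MaxOp {
  __device__ float operator()(float a, float b) const { return a > b ? a : b; }
};

__device__ void atomic_max_float(float* addr, float val) {
  unsigned int* addr_as_uint = reinterpret_cast<unsigned int*>(addr);
  unsigned int old_bits = *addr_as_uint, assumed;
  do {
    assumed = old_bits;
    float merged = MaxOp{}(__uint_as_float(assumed), val);
    old_bits = atomicCAS(addr_as_uint, assumed, __float_as_uint(merged));
  } while (old_bits != assumed); // retry if another thread won the race
}

__global__ void reduce_max(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomic_max_float(out, in[i]);
}

int main() {
  const int n = 1 << 16;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = (i * 37 % 1001) * 0.5f;
  *out = -1e30f; // reduction identity for max
  reduce_max<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("max = %f (expected 500.0)\n", *out);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

The comparison is always done on the decoded float values, so the bit-pattern reinterpretation only exists to give atomicCAS an integer word to operate on.
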
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
if (out.ndim() < in.ndim()) {
throw std::runtime_error(
"Reduction without keepdims only supported for row-contiguous inputs");
}
// Calculate the transpositions applied to in in order to apply them to out.
std::vector<int> axis_order(in.ndim());
std::iota(axis_order.begin(), axis_order.end(), 0);
std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
return in.strides(left) > in.strides(right);
});
// Transpose the shape and calculate the strides
Shape out_shape(in.ndim());
Strides out_strides(in.ndim(), 1);
for (int i = 0; i < in.ndim(); i++) {
out_shape[i] = out.shape(axis_order[i]);
}
for (int i = in.ndim() - 2; i >= 0; i--) {
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
}
// Reverse the axis order to get the final strides
Strides final_strides(in.ndim());
for (int i = 0; i < in.ndim(); i++) {
final_strides[axis_order[i]] = out_strides[i];
}
// Calculate the resulting contiguity and do the memory allocation
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
final_strides,
fl,
allocator::free);
}
} // namespace mlx::core
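
The stride bookkeeping in allocate_same_layout can be easier to follow in isolation. The sketch below (plain host-side C++, ignoring the reduction axes and the contiguity-flag computation, with illustrative names) shows only the core idea: order the axes by the input's strides, lay the output out contiguously in that order, then scatter the strides back to the original axis order. For a transposed {4, 3} view with strides {1, 4} it reproduces strides {1, 4} for the output.

// Sketch: compute output strides that follow the input's memory order.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

std::vector<long> same_layout_strides(
    const std::vector<int>& shape,
    const std::vector<long>& in_strides) {
  int ndim = shape.size();
  std::vector<int> order(ndim);
  std::iota(order.begin(), order.end(), 0);
  // Axes with the largest input stride come first (slowest moving).
  std::sort(order.begin(), order.end(), [&](int a, int b) {
    return in_strides[a] > in_strides[b];
  });
  // Contiguous strides for the permuted shape.
  std::vector<long> permuted(ndim, 1);
  for (int i = ndim - 2; i >= 0; --i) {
    permuted[i] = permuted[i + 1] * shape[order[i + 1]];
  }
  // Scatter back to the original axis order.
  std::vector<long> out(ndim);
  for (int i = 0; i < ndim; ++i) {
    out[order[i]] = permuted[i];
  }
  return out;
}

int main() {
  std::vector<int> shape{4, 3};
  std::vector<long> in_strides{1, 4}; // a transposed (column-major) view
  auto s = same_layout_strides(shape, in_strides);
  printf("out strides: {%ld, %ld}\n", s[0], s[1]); // expected {1, 4}
  return 0;
}
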
View File
@@ -1,7 +1,5 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -57,108 +55,84 @@ struct RowReduceArgs {
non_row_reductions *= reduce_shape[i]; non_row_reductions *= reduce_shape[i];
} }
} }
// Convert shape and strides as if in was contiguous
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
auto shape_vec = in.shape();
auto strides_vec = in.strides();
std::tie(shape_vec, strides_vec) =
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
std::vector<int> indices(shape_vec.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
decltype(shape_vec) sorted_shape;
decltype(strides_vec) sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
}
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(sorted_shape, sorted_strides);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
}
}; };
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1> template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { __global__ void row_reduce_small(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
const U init = cu::ReduceInit<ReduceOp, T>::value(); size_t out_idx = grid.thread_rank() / WARP_SIZE;
ReduceOp op; if (out_idx >= out_size) {
return;
T vals[M][N];
U accs[M];
for (int i = 0; i < M; i++) {
accs[i] = init;
} }
const size_t start_row = Op op;
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
out += start_row;
if (size % N == 0) { U total_val = ReduceInit<Op, T>::value();
for (size_t r = 0; r < full_blocks; r++) { LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
}
if (final_offset < size) { in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (int k = 0; k < M; k++) {
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
n += WARP_SIZE) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
block.thread_rank(), r,
in + k * size + final_offset, make_cast_iterator<U>(in + loop.location()),
vals[k], vals,
size, args.row_size,
__cast<T, U>(init)); ReduceInit<Op, T>::value());
for (int j = 0; j < N; j++) { total_val = op(total_val, cub::ThreadReduce(vals, op));
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
} }
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
} }
__shared__ U shared_accumulators[32 * M]; total_val = cg::reduce(warp, total_val, op);
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) { if (warp.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) { out[out_idx] = total_val;
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
out[i] = accs[i];
}
}
} }
} }
@@ -167,165 +141,55 @@ template <
typename U, typename U,
typename Op, typename Op,
int NDIM, int NDIM,
int BLOCK_DIM, int BLOCK_DIM_X,
int N_READS = 4> int N_READS = 4>
__global__ void row_reduce_looped( __global__ void row_reduce_looped(
T* in, const T* in,
U* out, U* out,
size_t out_size, size_t out_size,
const __grid_constant__ RowReduceArgs args) { const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.block_rank(); size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
if (out_idx >= out_size) {
return;
}
Op op; Op op;
U total[1]; U total_val = ReduceInit<Op, T>::value();
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) { for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) { for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
T vals[N_READS]; r++) {
cub::LoadDirectBlockedVectorized<T, N_READS>( U vals[N_READS];
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
block.thread_rank(), r * BLOCK_DIM_X + block.thread_index().x,
in + loop.location() + final_offset, make_cast_iterator<U>(in + loop.location()),
vals, vals,
args.row_size - final_offset, args.row_size,
__cast<T, U>(init)); ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) { total_val = op(total_val, cub::ThreadReduce(vals, op));
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
} }
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data()); loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
__shared__ U shared_accumulators[32]; typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
block_reduce(block, warp, total, shared_accumulators, op, init); __shared__ typename BlockReduceT::TempStorage temp;
total_val = BlockReduceT(temp).Reduce(total_val, op);
if (block.thread_rank() == 0) { if (block.thread_rank() == 0) {
out[out_idx] = total[0]; out[out_idx] = total_val;
} }
} }
} // namespace cu } // namespace cu
void row_reduce_simple(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
if (grid.x >= 1024) {
grid.x = (grid.x + 1) / 2;
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), plan.shape.back());
});
});
});
}
void row_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
block.x = THREADS;
});
});
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), args);
});
});
});
}
void row_reduce( void row_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -333,35 +197,54 @@ void row_reduce(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
// Current row reduction options
//
// - row_reduce_simple
//
// That means that we are simply reducing across the fastest moving axis.
// We are reducing 1 or 2 rows per threadblock depending on the size of
// output.
//
// - row_reduce_looped
//
// It is a general row reduction. We are computing 1 output per
// threadblock. We read the fastest moving axis vectorized and loop over
// the rest of the axes.
//
// Notes: We opt to read as much in order as possible and leave
// transpositions as they are (contrary to our Metal backend).
// Simple row reduce means that we have 1 axis that we are reducing over and
// it has stride 1.
if (plan.shape.size() == 1) {
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
return;
}
// Make the args struct to help route to the best kernel
cu::RowReduceArgs args(in, plan, axes); cu::RowReduceArgs args(in, plan, axes);
// Fallback row reduce encoder.launch_kernel([&](cudaStream_t stream) {
row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr size_t N_READS = 4;
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims, num_blocks;
auto kernel =
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
if (args.row_size <= 64) {
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
(args.non_row_reductions <= 8)) {
block_dims.x = std::min(out_dims.x, 1024u);
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
num_blocks.y = out_dims.y;
} else {
block_dims.x = WARP_SIZE;
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
kernel =
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
}
} else {
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
block_dims.x = BLOCK_DIM_X;
kernel = cu::row_reduce_looped<
InType,
OutType,
OP,
NDIM,
BLOCK_DIM_X,
N_READS>;
});
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), out.size(), args);
});
});
});
});
} }
} // namespace mlx::core } // namespace mlx::core
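
For reference, the cub::BlockReduce pattern used by the looped row kernel above can be sketched on its own: one block per contiguous row, each thread accumulating a strided slice of the row before the block-wide reduction. The 128-thread block, float-only types and the name row_sum are assumptions for illustration, not the kernels in this diff.

// Sketch: one thread block reduces one contiguous row with
// cub::BlockReduce, looping when the row is longer than the block.
#include <cstdio>
#include <cuda_runtime.h>
#include <cub/block/block_reduce.cuh>

constexpr int BLOCK_DIM = 128;

__global__ void row_sum(const float* in, float* out, int row_size) {
  const float* row = in + static_cast<size_t>(blockIdx.x) * row_size;
  float thread_total = 0.0f;
  // Each thread accumulates a strided slice of the row.
  for (int i = threadIdx.x; i < row_size; i += BLOCK_DIM) {
    thread_total += row[i];
  }
  using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage temp;
  float row_total = BlockReduce(temp).Sum(thread_total); // valid in thread 0
  if (threadIdx.x == 0) {
    out[blockIdx.x] = row_total;
  }
}

int main() {
  const int n_rows = 8, row_size = 1000;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * n_rows * row_size);
  cudaMallocManaged(&out, sizeof(float) * n_rows);
  for (int i = 0; i < n_rows * row_size; ++i) in[i] = 1.0f;
  row_sum<<<n_rows, BLOCK_DIM>>>(in, out, row_size);
  cudaDeviceSynchronize();
  printf("row 0 sum = %f (expected %d)\n", out[0], row_size);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
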
View File
@@ -0,0 +1,84 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>
namespace mlx::core {
template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}
template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}
struct MultiplyOp {
int factor;
__device__ int operator()(int i) {
return i * factor;
}
};
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
auto in_iter = cu::make_cast_iterator<OutType>(
thrust::device_pointer_cast(in.data<InType>()));
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
auto init = cu::ReduceInit<OP, InType>::value();
if (plan.type == ContiguousAllReduce) {
cub_all_reduce(
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
} else if (plan.type == ContiguousReduce) {
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
cub_segmented_reduce(
encoder,
in_iter,
out_ptr,
out.size(),
offsets,
offsets + 1,
OP(),
init,
stream);
} else {
throw std::runtime_error("Unsupported plan in segmented_reduce.");
}
});
});
});
}
} // namespace mlx::core
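
A hedged companion sketch for the segmented path above: when every segment has the same length, the begin/end offsets for cub::DeviceSegmentedReduce can be synthesized from a counting iterator wrapped in a transform iterator, so no offsets array is materialized; that is the same trick MultiplyOp is used for above. The segment sizes and names below are illustrative.

// Sketch: cub::DeviceSegmentedReduce over fixed-size segments, with the
// offsets generated by a transform iterator instead of a device array.
#include <cstdio>
#include <cuda_runtime.h>
#include <cub/device/device_segmented_reduce.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

struct TimesSegmentSize {
  int size;
  __host__ __device__ int operator()(int i) const { return i * size; }
};

int main() {
  const int segment_size = 256, num_segments = 64;
  const int n = segment_size * num_segments;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * n);
  cudaMallocManaged(&out, sizeof(float) * num_segments);
  for (int i = 0; i < n; ++i) in[i] = 1.0f;

  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), TimesSegmentSize{segment_size});

  // First call sizes the temporary storage, second call runs the reduce.
  void* temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedReduce::Sum(
      temp, temp_bytes, in, out, num_segments, offsets, offsets + 1);
  cudaMalloc(&temp, temp_bytes);
  cub::DeviceSegmentedReduce::Sum(
      temp, temp_bytes, in, out, num_segments, offsets, offsets + 1);
  cudaDeviceSynchronize();
  printf("segment 0 sum = %f (expected %d)\n", out[0], segment_size);
  cudaFree(temp);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
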
View File
@@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/fast_primitives.h"
namespace mlx::core {
namespace cu {} // namespace cu
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool supports_sdpa_vector = (query_sequence_length <= 1) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim;
return !supports_sdpa_vector;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
}
} // namespace fast
} // namespace mlx::core
View File
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
make_cast_iterator<AccT>(in), make_cast_iterator<AccT>(in),
vals, vals,
axis_size, axis_size,
Limits<AccT>::min()); Limits<AccT>::finite_min());
prevmax = maxval; prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax: // Online normalizer calculation for softmax:
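
For reference, the online-normalizer trick the comment refers to maintains a running maximum m and a running denominator d while streaming over the row; folding in a new chunk of values v can be written (a standard identity, not code from this diff) as

  m_{\text{new}} = \max\bigl(m_{\text{old}},\ \max_i v_i\bigr), \qquad
  d_{\text{new}} = d_{\text{old}}\, e^{\,m_{\text{old}} - m_{\text{new}}} + \sum_i e^{\,v_i - m_{\text{new}}},

so that after the last chunk \mathrm{softmax}(x_j) = e^{\,x_j - m_{\text{new}}} / d_{\text{new}} without a second pass over the data.
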
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
block.sync(); block.sync();
maxval = warp.thread_rank() < warp.meta_group_size() maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()] ? local_max[warp.thread_rank()]
: Limits<AccT>::min(); : Limits<AccT>::finite_min();
maxval = cg::reduce(warp, maxval, max_op); maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval); normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
View File
@@ -79,10 +79,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_; array out = out_;
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) { if (axis < 0) {
axis += in.ndim(); axis += in.ndim();
} }
int nsort = in.shape(axis); int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1; int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array, // If we are not sorting the innermost dimension of a contiguous array,
@@ -96,15 +100,9 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out); encoder.add_temporary(out);
} else { } else {
out.set_data( out.set_data(allocator::malloc(out.nbytes()));
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} }
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) { if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
@@ -136,7 +134,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
indices.data<uint32_t>(), indices.data<uint32_t>(),
out.data<uint32_t>(), out.data<uint32_t>(),
in.data_size(), in.data_size(),
in.data_size() / nsort, nsegments,
offsets, offsets,
offsets + 1, offsets + 1,
stream); stream);
@@ -146,7 +144,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data<Type>(), in.data<Type>(),
out.data<Type>(), out.data<Type>(),
in.data_size(), in.data_size(),
in.data_size() / nsort, nsegments,
offsets, offsets,
offsets + 1, offsets + 1,
stream); stream);
@@ -179,14 +177,4 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
gpu_sort(stream(), inputs[0], out, axis_, false); gpu_sort(stream(), inputs[0], out, axis_, false);
} }
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgPartition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Partition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core } // namespace mlx::core
View File
@@ -80,9 +80,7 @@ void Worker::thread_fn() {
} }
worker_tasks_.erase(worker_tasks_.begin(), end); worker_tasks_.erase(worker_tasks_.begin(), end);
} }
// Make sure tasks are cleared before the next wait for (auto& task : tasks) {
for (int i = 0; i < tasks.size(); ++i) {
auto task = std::move(tasks[i]);
task(); task();
} }
worker_event_.wait(batch + 1); worker_event_.wait(batch + 1);
View File
@@ -245,30 +245,6 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
} }
} }
// Any parent in the divider will continue to refer to `x` but any parent not
// in the divider will refer to a copy of the operation.
array split_one(
const array& x,
ParentsMap& parents_map,
const std::unordered_set<uintptr_t>& divider) {
array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
auto& x_parents = parents_map[x.id()];
auto& y_parents = parents_map[y.id()];
for (auto it = x_parents.begin(); it != x_parents.end();) {
if (divider.find(it->first.id()) != divider.end()) {
it->first.inputs()[it->second] = y;
y_parents.emplace_back(std::move(*it));
it = x_parents.erase(it);
} else {
it++;
}
}
return std::move(y);
}
template <typename T, typename... U> template <typename T, typename... U>
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) { std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
using FunType = T (*)(U...); using FunType = T (*)(U...);
@@ -693,16 +669,10 @@ void compile_fuse(
} }
// Arrays with a mix of parents outside the compilable section // Arrays with a mix of parents outside the compilable section
// are not fusable except for broadcast which we can split to avoid // are not fusable
// stopping fusion
if (!all_parents_in) { if (!all_parents_in) {
if (a.has_primitive() && is_broadcast(a.primitive())) { // Possible input
array b = split_one(a, parents_map, cache); input_set.insert(a.id());
recurse(b, depth, s, shape);
} else {
// Possible input
input_set.insert(a.id());
}
return; return;
} }
View File
@@ -546,7 +546,7 @@ class GELU(Module):
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
functional equivalents and information regarding error bounds. functional equivalents and information regarding error bounds.
Args: Args:
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
@@ -554,19 +554,20 @@ class GELU(Module):
def __init__(self, approx="none"): def __init__(self, approx="none"):
super().__init__() super().__init__()
self._approx = approx
allowed = ["none", "precise", "tanh", "fast"] if approx == "none":
if approx not in allowed: self._act = gelu
elif approx == "precise" or approx == "tanh":
self._act = gelu_approx
elif approx == "fast":
self._act = gelu_fast_approx
else:
raise ValueError( raise ValueError(
f"The approximation should be in {allowed} but '{approx}' was given" f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
) )
def __call__(self, x): def __call__(self, x):
if self._approx == "none": return self._act(x)
return gelu(x)
elif self._approx in ["precise", "tanh"]:
return gelu_approx(x)
return gelu_fast_approx(x)
@_make_activation_module(tanh) @_make_activation_module(tanh)
View File
@@ -404,7 +404,7 @@ class Module(dict):
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict and new_value != {}: elif strict:
raise ValueError( raise ValueError(
f"Received invalid type: {type(new_value).__name__}." f"Received invalid type: {type(new_value).__name__}."
) )
@@ -413,14 +413,14 @@ class Module(dict):
f'Module does not have sub-module named "{k}".' f'Module does not have sub-module named "{k}".'
) )
elif isinstance(modules, list): elif isinstance(modules, list):
for i in range(len(modules)): for i in range(len(dst)):
current_value = dst[i] current_value = dst[i]
new_value = modules[i] new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value): if self.is_module(current_value) and self.is_module(new_value):
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict and new_value != {}: elif strict:
raise ValueError( raise ValueError(
f"Received invalid type: {type(new_value).__name__}." f"Received invalid type: {type(new_value).__name__}."
) )

View File
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc*
cd wheelhouse
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .
View File
@@ -205,8 +205,6 @@ nb::object to_scalar(mx::array& a) {
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>())); return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
case mx::complex64: case mx::complex64:
return nb::cast(a.item<std::complex<float>>()); return nb::cast(a.item<std::complex<float>>());
case mx::float64:
return nb::cast(a.item<double>());
default: default:
throw nb::type_error("type cannot be converted to Python scalar."); throw nb::type_error("type cannot be converted to Python scalar.");
} }
View File
@@ -1,19 +1,37 @@
cuda_skip = { cuda_skip = {
"TestArray.test_api", "TestArray.test_api",
"TestAutograd.test_update_state",
"TestBF16.test_arg_reduction_ops", "TestBF16.test_arg_reduction_ops",
"TestBF16.test_reduction_ops",
"TestBlas.test_complex_gemm", "TestBlas.test_complex_gemm",
"TestCompile.test_compile_dynamic_dims",
"TestEinsum.test_ellipses", "TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases", "TestEinsum.test_opt_einsum_test_cases",
"TestLoad.test_load_f8_e4m3", "TestLoad.test_load_f8_e4m3",
"TestMemory.test_memory_info",
"TestLayers.test_group_norm", "TestLayers.test_group_norm",
"TestLayers.test_pooling", "TestLayers.test_pooling",
"TestLayers.test_quantized_embedding", "TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe", "TestLayers.test_sin_pe",
"TestLayers.test_upsample", "TestLayers.test_upsample",
"TestOps.test_array_equal",
"TestOps.test_complex_ops", "TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing", "TestOps.test_dynamic_slicing",
"TestOps.test_softmax",
"TestOps.test_sort",
"TestOps.test_tile",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes", "TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample", "TestUpsample.test_torch_upsample",
# DivMod NYI
"TestOps.test_divmod",
"TestEval.test_multi_output_eval_during_transform",
# Partition NYI
"TestAutograd.test_topk_grad",
"TestOps.test_argpartition",
"TestOps.test_partition",
# Block masked matmul NYI # Block masked matmul NYI
"TestBlas.test_block_masked_matmul", "TestBlas.test_block_masked_matmul",
# Gather matmul NYI # Gather matmul NYI
View File
@@ -2,10 +2,8 @@
import gc import gc
import io import io
import math
import unittest import unittest
from functools import partial from functools import partial
from io import StringIO
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
@@ -981,39 +979,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(mem_pre, mem_post) self.assertEqual(mem_pre, mem_post)
def test_double_constant(self):
with mx.stream(mx.cpu):
x = mx.array(1.0, dtype=mx.float64)
def fun(x):
return (x + math.pi) * 2.0
y = fun(x).item()
y_compiled = mx.compile(fun)(x).item()
self.assertEqual(y, y_compiled)
def test_shared_broadcast(self):
def fun(x, y, z):
yy = mx.broadcast_to(y, z.shape)
return (x + yy * z), yy.sum()
a = mx.random.normal((10, 10))
b = mx.array(0.1)
c = mx.random.normal((10, 10))
mx.eval(a, b, c)
fc = mx.compile(fun)
d = fc(a, b, c)
s = StringIO()
mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
s.seek(0)
s = s.read()
self.assertTrue("CompiledBroadcastMultiplyAdd" in s)
d_hat = fun(a, b, c)
self.assertTrue(mx.allclose(d[0], d_hat[0]))
self.assertTrue(mx.allclose(d[1], d_hat[1]))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()
View File
@@ -259,21 +259,6 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]}) m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
# Using leaf_modules in the update should always work
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
m = MyModel()
m.update_modules(m.leaf_modules())
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):
View File
@@ -174,26 +174,20 @@ if __name__ == "__main__":
) )
package_dir = {"": "python"} package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
install_requires = []
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
setup( setup(
name="mlx-cuda" if build_cuda else "mlx", name="mlx",
version=get_version(), version=get_version(),
author="MLX Contributors", author="MLX Contributors",
author_email="mlx@group.apple.com", author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.", description="A framework for machine learning on Apple silicon.",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license="MIT",
url="https://github.com/ml-explore/mlx", url="https://github.com/ml-explore/mlx",
packages=packages, packages=packages,
package_dir=package_dir, package_dir=package_dir,
package_data=package_data, package_data=package_data,
include_package_data=True, include_package_data=True,
install_requires=install_requires,
extras_require={ extras_require={
"dev": [ "dev": [
"nanobind==2.4.0", "nanobind==2.4.0",