Compare commits

14 Commits

Author SHA1 Message Date
Awni Hannun
58f3860306 patch bump (#2324) 2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63 use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507 MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
33bf1a244b Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
772f471ff2 [CUDA] Fix reductions (#2314) 2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d Split broadcast so it is always fused in compile (#2318) 2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780 Fix get 2d grid dims (#2316) 2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e Compile float64 functions on CPU (#2311) 2025-06-24 10:18:52 -07:00
Angelos Katharopoulos
5adf185f86 Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584 Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376 Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c [CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
63 changed files with 2200 additions and 1203 deletions

View File

@@ -16,6 +16,9 @@ parameters:
   linux_release:
     type: boolean
     default: false
+  cuda_release:
+    type: boolean
+    default: false

 jobs:
   build_documentation:
@@ -104,7 +107,7 @@ jobs:
           command: |
             echo "stubs"
             pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
@@ -162,7 +165,7 @@ jobs:
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
@@ -223,7 +226,6 @@ jobs:
          command: |
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-           sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
            python -m venv env
            source env/bin/activate
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
@@ -283,7 +285,7 @@ jobs:
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
@@ -342,7 +344,7 @@ jobs:
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              pip install . -v
            pip install typing_extensions
            python setup.py generate_stubs
            << parameters.extra_env >> \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              python -m build --wheel
@@ -356,6 +358,48 @@ jobs:
      - store_artifacts:
          path: wheelhouse/

+  build_cuda_release:
+    parameters:
+      python_version:
+        type: string
+        default: "3.9"
+      extra_env:
+        type: string
+        default: "DEV_RELEASE=1"
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Build wheel
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            python -m venv env
+            source env/bin/activate
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            pip install twine
+            << parameters.extra_env >> \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              pip install ".[dev]" -v
+            python setup.py generate_stubs
+            << parameters.extra_env >> \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              python -m build --wheel
+            bash python/scripts/repair_cuda.sh
+      - run:
+          name: Upload package
+          command: |
+            source env/bin/activate
+            twine upload wheelhouse/*.whl
+      - store_artifacts:
+          path: wheelhouse/
+
 workflows:
   build_and_test:
     when:
@@ -625,3 +669,14 @@ workflows:
          parameters:
            python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
            extra_env: ["PYPI_RELEASE=1"]
+  cuda_test_release:
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.cuda_release >>
+    jobs:
+      - build_cuda_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              extra_env: ["PYPI_RELEASE=1"]

View File

@@ -5,6 +5,7 @@ import os
 import time

 import torch
+import torch.cuda
 import torch.mps
@@ -44,8 +45,10 @@ def bench(f, *args):

 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()

 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)

+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)

     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"

     types = args.dtype
     if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:

    conda install conda-forge::mlx

+CUDA
+^^^^
+
+MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
+and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
+
+.. code-block:: shell
+
+   pip install mlx-cuda
+
 Troubleshooting
 ^^^^^^^^^^^^^^^
@@ -65,6 +75,8 @@ Build Requirements
 Python API
 ^^^^^^^^^^

+.. _python install:
+
 To build and install the MLX python library from source, first, clone MLX from
 `its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -107,6 +119,8 @@ IDE:
 C++ API
 ^^^^^^^

+.. _cpp install:
+
 Currently, MLX must be built and installed from source.

 Similarly to the python library, to build and install the MLX C++ library start
@@ -185,6 +199,7 @@ should point to the path to the built metal library.

    xcrun -sdk macosx --show-sdk-version

 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -213,6 +228,50 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
 Metal kernel cache persists across reboots.

+Linux
+^^^^^
+
+To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
+For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+   apt-get update -y
+   apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+From here follow the instructions to install either the :ref:`Python <python
+install>` or :ref:`C++ <cpp install>` APIs.
+
+CUDA
+^^^^
+
+To build from source on Linux with CUDA, install the BLAS and LAPACK headers
+and the CUDA toolkit. For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+   wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+   dpkg -i cuda-keyring_1.1-1_all.deb
+   apt-get update -y
+   apt-get -y install cuda-toolkit-12-9
+   apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+When building either the Python or C++ APIs make sure to pass the cmake flag
+``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
+
+.. code-block:: shell
+
+   CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+
+To build the C++ package run:
+
+.. code-block:: shell
+
+   mkdir -p build && cd build
+   cmake .. -DMLX_BUILD_CUDA=ON && make -j
+
 Troubleshooting
 ^^^^^^^^^^^^^^^

View File

@@ -14,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
       return print_float_constant<float16_t>(os, x);
     case bfloat16:
       return print_float_constant<bfloat16_t>(os, x);
+    case float64:
+      return print_float_constant<double>(os, x);
     case complex64:
       return print_complex_constant<complex64_t>(os, x);
     case int8:
@@ -50,6 +52,8 @@ std::string get_type_string(Dtype d) {
       return "float16_t";
     case bfloat16:
       return "bfloat16_t";
+    case float64:
+      return "double";
     case complex64:
       return "complex64_t";
     case bool_:

View File

@@ -18,8 +18,12 @@ std::string get_type_string(Dtype d);
 template <typename T>
 void print_float_constant(std::ostream& os, const array& x) {
   auto old_precision = os.precision();
-  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
-     << x.item<T>() << std::setprecision(old_precision);
+  if constexpr (std::is_same_v<T, double>) {
+    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
+  } else {
+    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
+  }
+  os << x.item<T>() << std::setprecision(old_precision);
 }

 template <typename T>
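A quick standalone illustration of why the double branch matters (ordinary host C++, not part of the change above): printing at float precision drops digits that a float64 constant needs in order to round-trip.

#include <iomanip>
#include <iostream>
#include <limits>

int main() {
  double x = 0.1234567890123456789;
  // float precision: digits10 + 1 == 7 significant digits, roughly 0.1234568
  std::cout << std::setprecision(std::numeric_limits<float>::digits10 + 1) << x
            << "\n";
  // double precision: digits10 + 1 == 16 significant digits,
  // roughly 0.1234567890123457
  std::cout << std::setprecision(std::numeric_limits<double>::digits10 + 1) << x
            << "\n";
  return 0;
}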

View File

@@ -5,11 +5,9 @@
 namespace mlx::core {

 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }

+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);

 } // namespace mlx::core

View File

@@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common(
       }
     }
   }
-  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
     throw std::runtime_error("Unable to safely factor shape.");
   }
   if (grid_y > grid_x) {
     std::swap(grid_x, grid_y);
   }
+  if (divisor > 1) {
+    grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
+  }
   return std::make_tuple(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
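The new `divisor > 1` branch is the standard ceil-to-a-multiple idiom. A small standalone sketch (independent of the MLX sources) shows the arithmetic:

#include <cassert>
#include <cstdint>

// Round n up to the nearest multiple of divisor (divisor > 0).
uint64_t round_up_to_multiple(uint64_t n, uint64_t divisor) {
  return ((n + divisor - 1) / divisor) * divisor;
}

int main() {
  // e.g. a grid_x of 1000 with a divisor of 128 becomes 1024, so the
  // resulting grid dimension stays divisible by the requested factor.
  assert(round_up_to_multiple(1000, 128) == 1024);
  assert(round_up_to_multiple(1024, 128) == 1024); // already a multiple
  assert(round_up_to_multiple(1, 128) == 128);
  return 0;
}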

View File

@@ -8,6 +8,7 @@ target_sources(
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
@@ -28,9 +29,10 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+         ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+         ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
-         ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp

View File

@@ -3,6 +3,7 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
+#include "mlx/utils.h"

 #include <cuda_runtime.h>
 #include <fmt/format.h>
@@ -14,9 +15,11 @@ namespace mlx::core {

 namespace cu {

+constexpr int page_size = 16384;
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
-          getpagesize(),
+          page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) {
             cuda_free(buf->data);
@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()

 Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
+  auto orig_size = size;
   std::unique_lock lock(mutex_);
+  if (size < page_size) {
+    size = next_power_of_2(size);
+  } else {
+    size = page_size * ((size + page_size - 1) / page_size);
+  }
+
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,
@@ -106,7 +116,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
   cudaFree(buf);
 }
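The new allocation rounding is easy to check in isolation. A minimal sketch, assuming the 16 KB page size from the diff and a simple stand-in for the MLX next_power_of_2 helper:

#include <cassert>
#include <cstddef>

constexpr size_t page_size = 16384; // 16 KB, as in the change above

// Stand-in for the next-power-of-two helper used by the allocator.
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Bucket a requested size the way the allocator now does: small requests go to
// power-of-two buckets, larger ones to whole pages, so cached buffers can be
// reused across slightly different request sizes.
size_t bucket_size(size_t size) {
  return size < page_size ? next_power_of_2(size)
                          : page_size * ((size + page_size - 1) / page_size);
}

int main() {
  assert(bucket_size(100) == 128);
  assert(bucket_size(5000) == 8192);
  assert(bucket_size(16384) == 16384);         // exactly one page
  assert(bucket_size(20000) == 2 * page_size); // rounded up to 32768
  return 0;
}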

View File

@@ -152,35 +152,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-        dim3 block_dims{BLOCK_DIM, 1, 1};
-        auto kernel = &cu::arg_reduce_general<
-            InType,
-            cu::ArgMax<InType>,
-            BLOCK_DIM,
-            N_READS>;
-        if (reduce_type_ == ArgReduce::ArgMin) {
-          kernel = &cu::arg_reduce_general<
-              InType,
-              cu::ArgMin<InType>,
-              BLOCK_DIM,
-              N_READS>;
-        }
-        kernel<<<num_blocks, block_dims, 0, stream>>>(
-            in.data<InType>(),
-            out.data<uint32_t>(),
-            out.size(),
-            const_param(shape),
-            const_param(in_strides),
-            const_param(out_strides),
-            ndim,
-            axis_stride,
-            axis_size);
-      });
+    dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr uint32_t N_READS = 4;
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
+            auto kernel =
+                cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
+            if (reduce_type_ == ArgReduce::ArgMin) {
+              kernel = cu::
+                  arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
+            }
+            kernel<<<num_blocks, block_dim(), 0, stream>>>(
+                in.data<T>(),
+                out.data<uint32_t>(),
+                out.size(),
+                const_param(shape),
+                const_param(in_strides),
+                const_param(out_strides),
+                ndim,
+                axis_stride,
+                axis_size);
+          });
     });
   });
 }

View File

@@ -125,13 +125,12 @@ constexpr bool supports_binary_op() {
 template <typename Op>
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
-    std::vector<array>& outputs,
+    array& out,
     std::string_view op,
     const Stream& s) {
   assert(inputs.size() > 1);
   const auto& a = inputs[0];
   const auto& b = inputs[1];
-  auto& out = outputs[0];
   if (out.size() == 0) {
     return;
   }
@@ -141,55 +140,64 @@ void binary_op_gpu_inplace(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) { if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>; using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out); dispatch_bool(
auto& a_strides = strides[0]; a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
auto& b_strides = strides[1]; out.data_size() > INT32_MAX,
bool large = a.data_size() > INT32_MAX || [&](auto large) {
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_BOOL(large, LARGE, { Shape shape;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; std::vector<Strides> strides;
int ndim = shape.size(); std::tie(shape, strides) =
if (ndim <= 3) { collapse_contiguous_dims(a, b, out);
MLX_SWITCH_1_2_3(ndim, NDIM, { auto& a_strides = strides[0];
auto kernel = auto& b_strides = strides[1];
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>; int ndim = shape.size();
auto [num_blocks, block_dims] = if (ndim <= 3) {
get_launch_args(kernel, out, large); dispatch_1_2_3(ndim, [&](auto dims_constant) {
kernel<<<num_blocks, block_dims, 0, stream>>>( auto kernel = cu::binary_g_nd<
a.data<InType>(), Op,
b.data<InType>(), InType,
out.data<OutType>(), OutType,
out.size(), IdxT,
const_param<NDIM>(shape), dims_constant()>;
const_param<NDIM>(a_strides), auto [num_blocks, block_dims] =
const_param<NDIM>(b_strides)); get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
}); });
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else { } else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>; auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>; kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -199,7 +207,7 @@ void binary_op_gpu_inplace(
               kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
             }
             auto [num_blocks, block_dims] = get_launch_args(
-                kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+                kernel, out.data_size(), out.shape(), out.strides(), large());
             kernel<<<num_blocks, block_dims, 0, stream>>>(
                 a.data<InType>(),
                 b.data<InType>(),
@@ -219,20 +227,6 @@ void binary_op_gpu_inplace(
   });
 }

-template <typename Op>
-void binary_op_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs,
-    std::string_view op,
-    const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt);
-  set_binary_op_output_data(a, b, outputs[1], bopt);
-  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
-}
-
 template <typename Op>
 void binary_op_gpu(
     const std::vector<array>& inputs,
@@ -243,8 +237,7 @@ void binary_op_gpu(
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
-  std::vector<array> outputs{out};
-  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+  binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }

 #define BINARY_GPU(func) \
@@ -254,14 +247,6 @@ void binary_op_gpu(
     binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
   }

-#define BINARY_GPU_MULTI(func)                                               \
-  void func::eval_gpu(                                                       \
-      const std::vector<array>& inputs, std::vector<array>& outputs) {       \
-    nvtx3::scoped_range r(#func "::eval_gpu");                               \
-    auto& s = outputs[0].primitive().stream();                               \
-    binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
-  }
-
 BINARY_GPU(Add)
 BINARY_GPU(ArcTan2)
 BINARY_GPU(Divide)

View File

@@ -0,0 +1,258 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out_a.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) =
collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core

View File

@@ -24,7 +24,6 @@ void copy_gpu_inplace(
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
-
   if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
     copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
     return;

View File

@@ -10,15 +10,6 @@

 namespace mlx::core {

-#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
-  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {               \
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {           \
-      using InType = cuda_type_t<CTYPE_IN>;                  \
-      using OutType = cuda_type_t<CTYPE_OUT>;                \
-      __VA_ARGS__;                                           \
-    });                                                      \
-  })
-
 void copy_contiguous(
     cu::CommandEncoder& encoder,
     CopyType ctype,

View File

@@ -36,19 +36,23 @@ void copy_contiguous(
     int64_t in_offset,
     int64_t out_offset) {
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
-      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
-        auto kernel = cu::copy_s<InType, OutType, IdxT>;
-        if (ctype == CopyType::Vector) {
-          kernel = cu::copy_v<InType, OutType, IdxT>;
-        }
-        auto [num_blocks, block_dims] = get_launch_args(
-            kernel, out.data_size(), out.shape(), out.strides(), LARGE);
-        kernel<<<num_blocks, block_dims, 0, stream>>>(
-            in.data<InType>() + in_offset,
-            out.data<OutType>() + out_offset,
-            out.data_size());
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+          using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+          using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+          auto kernel = cu::copy_s<InType, OutType, IdxT>;
+          if (ctype == CopyType::Vector) {
+            kernel = cu::copy_v<InType, OutType, IdxT>;
+          }
+          auto [num_blocks, block_dims] = get_launch_args(
+              kernel, out.data_size(), out.shape(), out.strides(), large());
+          kernel<<<num_blocks, block_dims, 0, stream>>>(
+              in.data<InType>() + in_offset,
+              out.data<OutType>() + out_offset,
+              out.data_size());
+        });
       });
     });
   });

View File

@@ -56,37 +56,48 @@ void copy_general(
const Strides& strides_in, const Strides& strides_in,
const Strides& strides_out) { const Strides& strides_out) {
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
const InType* in_ptr = in.data<InType>() + offset_in; dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
OutType* out_ptr = out.data<OutType>() + offset_out; dispatch_bool(
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
MLX_SWITCH_BOOL(large, LARGE, { [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
int ndim = shape.size(); using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
if (ndim <= 3) { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_1_2_3(ndim, NDIM, { const InType* in_ptr = in.data<InType>() + offset_in;
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>; OutType* out_ptr = out.data<OutType>() + offset_out;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); int ndim = shape.size();
kernel<<<num_blocks, block_dims, 0, stream>>>( size_t data_size = 1;
in_ptr, for (auto& s : shape)
out_ptr, data_size *= s;
out.size(), if (ndim <= 3) {
const_param<NDIM>(shape), dispatch_1_2_3(ndim, [&](auto ndim_constant) {
const_param<NDIM>(strides_in), auto kernel =
const_param<NDIM>(strides_out)); cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
}); auto [num_blocks, block_dims] = get_launch_args(
} else { // ndim >= 4 kernel, data_size, shape, out.strides(), large());
auto kernel = cu::copy_gg<InType, OutType, IdxT>; kernel<<<num_blocks, block_dims, 0, stream>>>(
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); in_ptr,
kernel<<<num_blocks, block_dims, 0, stream>>>( out_ptr,
in_ptr, data_size,
out_ptr, const_param<ndim_constant()>(shape),
out.size(), const_param<ndim_constant()>(strides_in),
const_param(shape), const_param<ndim_constant()>(strides_out));
const_param(strides_in), });
const_param(strides_out), } else { // ndim >= 4
ndim); auto kernel = cu::copy_gg<InType, OutType, IdxT>;
} auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
}); });
}); });
}); });

View File

@@ -62,41 +62,52 @@ void copy_general_dynamic(
const array& dynamic_offset_in, const array& dynamic_offset_in,
const array& dynamic_offset_out) { const array& dynamic_offset_out) {
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
const InType* in_ptr = in.data<InType>() + offset_in; dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
OutType* out_ptr = out.data<OutType>() + offset_out; dispatch_bool(
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
MLX_SWITCH_BOOL(large, LARGE, { [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
int ndim = shape.size(); using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
if (ndim <= 3) { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_1_2_3(ndim, NDIM, { const InType* in_ptr = in.data<InType>() + offset_in;
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>; OutType* out_ptr = out.data<OutType>() + offset_out;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); int ndim = shape.size();
kernel<<<num_blocks, block_dims, 0, stream>>>( if (ndim <= 3) {
in_ptr, dispatch_1_2_3(ndim, [&](auto dims_constant) {
out_ptr, auto kernel = cu::copy_gg_dynamic_nd<
out.size(), InType,
const_param<NDIM>(shape), OutType,
const_param<NDIM>(strides_in), IdxT,
const_param<NDIM>(strides_out), dims_constant()>;
dynamic_offset_in.data<int64_t>(), auto [num_blocks, block_dims] =
dynamic_offset_out.data<int64_t>()); get_launch_args(kernel, out, large());
}); kernel<<<num_blocks, block_dims, 0, stream>>>(
} else { // ndim >= 4 in_ptr,
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>; out_ptr,
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); out.size(),
kernel<<<num_blocks, block_dims, 0, stream>>>( const_param<dims_constant()>(shape),
in_ptr, const_param<dims_constant()>(strides_in),
out_ptr, const_param<dims_constant()>(strides_out),
out.size(), dynamic_offset_in.data<int64_t>(),
const_param(shape), dynamic_offset_out.data<int64_t>());
const_param(strides_in), });
const_param(strides_out), } else { // ndim >= 4
ndim, auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
dynamic_offset_in.data<int64_t>(), auto [num_blocks, block_dims] =
dynamic_offset_out.data<int64_t>()); get_launch_args(kernel, out, large());
} kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
});
}); });
}); });
}); });

View File

@@ -51,35 +51,43 @@ void copy_general_input(
const Shape& shape, const Shape& shape,
const Strides& strides_in) { const Strides& strides_in) {
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
const InType* in_ptr = in.data<InType>() + offset_in; dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
OutType* out_ptr = out.data<OutType>() + offset_out; dispatch_bool(
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
MLX_SWITCH_BOOL(large, LARGE, { [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
int ndim = shape.size(); using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
if (ndim <= 3) { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_1_2_3(ndim, NDIM, { const InType* in_ptr = in.data<InType>() + offset_in;
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>; OutType* out_ptr = out.data<OutType>() + offset_out;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); int ndim = shape.size();
kernel<<<num_blocks, block_dims, 0, stream>>>( if (ndim <= 3) {
in_ptr, dispatch_1_2_3(ndim, [&](auto dims_constant) {
out_ptr, auto kernel =
out.size(), cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
const_param<NDIM>(shape), auto [num_blocks, block_dims] =
const_param<NDIM>(strides_in)); get_launch_args(kernel, out, large());
}); kernel<<<num_blocks, block_dims, 0, stream>>>(
} else { // ndim >= 4 in_ptr,
auto kernel = cu::copy_g<InType, OutType, IdxT>; out_ptr,
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); out.size(),
kernel<<<num_blocks, block_dims, 0, stream>>>( const_param<dims_constant()>(shape),
in_ptr, const_param<dims_constant()>(strides_in));
out_ptr, });
out.size(), } else { // ndim >= 4
const_param(shape), auto kernel = cu::copy_g<InType, OutType, IdxT>;
const_param(strides_in), auto [num_blocks, block_dims] =
ndim); get_launch_args(kernel, out, large());
} kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
});
}); });
}); });
}); });

View File

@@ -6,6 +6,7 @@

 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>

 namespace mlx::core {
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }

+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);
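The synchronize() added above waits by chaining a promise through the completion handlers. A minimal standalone sketch of that pattern, with a plain thread standing in for the CUDA worker:

#include <functional>
#include <future>
#include <thread>
#include <vector>

int main() {
  // Handlers queued behind earlier work (here: behind a plain worker thread).
  std::vector<std::function<void()>> handlers;

  // Enqueue a handler that fulfils a promise, then block on the future until
  // everything queued before it, plus the handler itself, has run.
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  handlers.push_back([p = std::move(p)]() { p->set_value(); });

  std::thread worker([&handlers]() {
    for (auto& h : handlers) {
      h(); // work queued before the sync point runs first
    }
  });

  f.wait(); // unblocks only once the worker reached our handler
  worker.join();
  return 0;
}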

View File

@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }

+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
   Device& device_;
   DeviceStream& stream_;

View File

@@ -22,7 +22,7 @@ struct FloorDivide {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x / y;
     } else {
-      return trunc(x / y);
+      return truncf(x / y);
     }
   }
 };
@@ -132,7 +132,7 @@ struct LogAddExp {
           cuda::std::numeric_limits<float>::quiet_NaN(),
           cuda::std::numeric_limits<float>::quiet_NaN()};
     }
-    constexpr float inf = cuda::std::numeric_limits<float>::infinity();
+    float inf = cuda::std::numeric_limits<float>::infinity();
     auto maxval = x > y ? x : y;
     auto minval = x < y ? x : y;
     if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)

View File

@@ -5,7 +5,7 @@
 #pragma once

 // The maximum dimensions of shape/strides passed as kernel parameters.
-#define MAX_NDIM 8
+#define MAX_NDIM 10

 // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
 // warpSize variable exists, using it would prevent compile-time optimizations.

View File

@@ -27,6 +27,8 @@ struct ArcCos {
   __device__ T operator()(T x) {
     return acos(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcCosh {
@@ -41,6 +43,8 @@ struct ArcSin {
   __device__ T operator()(T x) {
     return asin(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcSinh {
@@ -55,6 +59,8 @@ struct ArcTan {
   __device__ T operator()(T x) {
     return atan(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcTanh {
@@ -261,13 +267,6 @@ struct Round {
   }
 };

-struct Rsqrt {
-  template <typename T>
-  __device__ T operator()(T x) {
-    return rsqrt(x);
-  }
-};
-
 struct Sigmoid {
   template <typename T>
   __device__ T operator()(T x) {
@@ -333,6 +332,29 @@ struct Sqrt {
   __device__ T operator()(T x) {
     return sqrt(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x) {
+    auto xr = cuCrealf(x);
+    auto xi = cuCimagf(x);
+    if (xr == 0.0f && xi == 0.0f) {
+      return {0.0f, 0.0f};
+    }
+    auto r = cuCrealf(Abs{}(x));
+    auto a = sqrt((r + xr) / 2.0f);
+    auto b_abs = sqrt((r - xr) / 2.0f);
+    auto b = copysign(b_abs, xi);
+    return {a, b};
+  }
+};
+
+struct Rsqrt {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return rsqrt(x);
+  }
+  __device__ cuComplex operator()(cuComplex x) {
+    return 1.0f / Sqrt{}(x);
+  }
 };

 struct Tan {
@@ -365,4 +387,22 @@ struct Tanh {
   }
 };

+__device__ cuComplex ArcCos::operator()(cuComplex x) {
+  auto i = cuComplex{0.0, 1.0};
+  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcSin::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcTan::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto ix = i * x;
+  return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
+};
+
 } // namespace mlx::core::cu
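For a host-side sanity check of the square-root-by-halves formula used in the complex Sqrt overload (a standalone sketch using std::complex, not CUDA code):

#include <cassert>
#include <cmath>
#include <complex>
#include <cstdio>

// Same formula as above: for z = x + iy with r = |z|,
// sqrt(z) = a + ib where a = sqrt((r + x) / 2) and b = sign(y) * sqrt((r - x) / 2).
std::complex<float> sqrt_by_halves(std::complex<float> z) {
  float r = std::abs(z);
  float a = std::sqrt((r + z.real()) / 2.0f);
  float b = std::copysign(std::sqrt((r - z.real()) / 2.0f), z.imag());
  return {a, b};
}

int main() {
  std::complex<float> z{-3.0f, 4.0f};
  auto got = sqrt_by_halves(z);
  auto want = std::sqrt(z); // 1 + 2i for -3 + 4i
  assert(std::abs(got - want) < 1e-5f);
  std::printf("sqrt(-3+4i) = %g + %gi\n", got.real(), got.imag());
  return 0;
}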

View File

@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   IdxT b_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
   IdxT c_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
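A short aside on what the added IdxT(...) casts do (a standalone sketch; the concrete types are only illustrative): without the cast, the usual arithmetic conversions promote each per-dimension product to the 64-bit stride type, while the cast keeps the index arithmetic in whatever IdxT was selected at dispatch time.

#include <cstdint>
#include <type_traits>

int main() {
  using IdxT = int32_t; // the "small index" dispatch case
  int64_t stride = 128; // strides are stored as 64-bit values
  int dim_idx = 7;
  (void)stride;
  (void)dim_idx;

  // Mixed int * int64_t arithmetic is promoted to the 64-bit type.
  static_assert(std::is_same_v<decltype(dim_idx * stride), int64_t>);

  // Casting the stride first keeps the multiply in int (32-bit on mainstream
  // platforms), matching the IdxT chosen for the kernel launch.
  static_assert(std::is_same_v<decltype(dim_idx * IdxT(stride)), int>);

  return 0;
}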

View File

@@ -62,7 +62,7 @@ void finalize(Stream s) {

 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }

 } // namespace mlx::core::gpu

View File

@@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) {
 }

 // Return the location of the CUDA toolkit.
-const char* cuda_home() {
-  const char* home = std::getenv("CUDA_HOME");
-  if (home) {
-    return home;
-  }
-  home = std::getenv("CUDA_PATH");
-  if (home) {
-    return home;
-  }
+const std::string& cuda_home() {
+  static std::string home = []() -> std::string {
+    const char* home = std::getenv("CUDA_HOME");
+    if (home) {
+      return home;
+    }
+    home = std::getenv("CUDA_PATH");
+    if (home) {
+      return home;
+    }
 #if defined(__linux__)
-  home = "/usr/local/cuda";
-  if (std::filesystem::exists(home)) {
-    return home;
-  }
+    home = "/usr/local/cuda";
+    if (std::filesystem::exists(home)) {
+      return home;
+    }
 #endif
-  throw std::runtime_error(
-      "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+    throw std::runtime_error(
+        "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+  }();
+  return home;
 }

 // Get the cache directory for storing compiled results.
-bool get_ptx_cache_dir(std::filesystem::path* result) {
-  auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
-  if (!std::filesystem::is_directory(path)) {
-    std::error_code error;
-    if (!std::filesystem::create_directories(path, error)) {
-      return false;
-    }
-  }
-  *result = path;
-  return true;
+const std::filesystem::path& ptx_cache_dir() {
+  static std::filesystem::path cache = []() -> std::filesystem::path {
+    std::filesystem::path cache;
+    if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
+      cache = c;
+    } else {
+      cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
+    }
+    if (!std::filesystem::exists(cache)) {
+      std::error_code error;
+      if (!std::filesystem::create_directories(cache, error)) {
+        return std::filesystem::path();
+      }
+    }
+    return cache;
+  }();
+  return cache;
 }

 // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
@@ -75,6 +85,10 @@ bool read_cached_ptx(
     const std::string& module_name,
     std::vector<char>* ptx,
     std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+  if (cache_dir.empty()) {
+    return false;
+  }
+
   auto ptx_path = cache_dir / (module_name + ".ptx");
   std::error_code error;
   auto ptx_size = std::filesystem::file_size(ptx_path, error);
@@ -105,6 +119,10 @@ void write_cached_ptx(
     const std::string& module_name,
     const std::vector<char>& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+  if (cache_dir.empty()) {
+    return;
+  }
+
   std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
   if (!ptx.empty()) {
     ptx_file.write(&ptx.front(), ptx.size());
@@ -184,11 +202,9 @@ JitModule::JitModule(
     const std::string& module_name,
     const KernelBuilder& builder) {
   // Check cache.
-  std::filesystem::path cache_dir;
   std::vector<char> ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!get_ptx_cache_dir(&cache_dir) ||
-      !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
     // Create program.
     auto [source_code, kernel_names] = builder();
     nvrtcProgram prog;
@@ -246,7 +262,7 @@ JitModule::JitModule(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
   }
-  write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
+  write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
 }

 // Load module.

View File

@@ -6,6 +6,8 @@
 #pragma once

+#include <type_traits>
+
 #include "mlx/array.h"
 #include "mlx/backend/cuda/device/utils.cuh"
@@ -17,60 +19,46 @@
 namespace mlx::core {

-// Convert a number between 1~3 to constexpr.
-#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
-  switch (N) {                         \
-    case 1: {                          \
-      constexpr int NDIM = 1;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
-    case 2: {                          \
-      constexpr int NDIM = 2;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
-    case 3: {                          \
-      constexpr int NDIM = 3;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
-  }
-
-// Like MLX_SWITCH_ALL_TYPES but for booleans.
-#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
-  if (BOOL) {                                  \
-    constexpr bool BOOL_ALIAS = true;          \
-    __VA_ARGS__;                               \
-  } else {                                     \
-    constexpr bool BOOL_ALIAS = false;         \
-    __VA_ARGS__;                               \
-  }
-
-// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
-#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...)   \
-  {                                                         \
-    uint32_t _num_threads = NUM_THREADS;                    \
-    if (_num_threads <= WARP_SIZE) {                        \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE;             \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 2) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 4) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 8) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 16) {            \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16;        \
-      __VA_ARGS__;                                          \
-    } else {                                                \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
-      __VA_ARGS__;                                          \
-    }                                                       \
-  }
+template <typename F>
+void dispatch_1_2_3(int n, F&& f) {
+  switch (n) {
+    case 1:
+      f(std::integral_constant<int, 1>{});
+      break;
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 3:
+      f(std::integral_constant<int, 3>{});
+      break;
+  }
+}
+
+template <typename F>
+void dispatch_bool(bool v, F&& f) {
+  if (v) {
+    f(std::true_type{});
+  } else {
+    f(std::false_type{});
+  }
+}
+
+template <typename F>
+void dispatch_block_dim(int threads, F&& f) {
+  if (threads <= WARP_SIZE) {
+    f(std::integral_constant<int, WARP_SIZE>{});
+  } else if (threads <= WARP_SIZE * 2) {
+    f(std::integral_constant<int, WARP_SIZE * 2>{});
+  } else if (threads <= WARP_SIZE * 4) {
+    f(std::integral_constant<int, WARP_SIZE * 4>{});
+  } else if (threads <= WARP_SIZE * 8) {
+    f(std::integral_constant<int, WARP_SIZE * 8>{});
+  } else if (threads <= WARP_SIZE * 16) {
+    f(std::integral_constant<int, WARP_SIZE * 16>{});
+  } else {
+    f(std::integral_constant<int, WARP_SIZE * 32>{});
+  }
+}

 // Maps CPU types to CUDA types.
 template <typename T>
View File

@@ -259,21 +259,22 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
using DataType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4; constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
x.data<DataType>(), auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
w.data<DataType>(), kernel<<<n_rows, block_dim(), 0, stream>>>(
b.data<DataType>(), x.data<DataType>(),
out.data<DataType>(), w.data<DataType>(),
eps_, b.data<DataType>(),
axis_size, out.data<DataType>(),
w_stride, eps_,
b_stride); axis_size,
}); w_stride,
b_stride);
});
}); });
}); });
} }
@@ -341,8 +342,6 @@ void LayerNormVJP::eval_gpu(
encoder.add_temporary(gw_temp); encoder.add_temporary(gw_temp);
} }
} }
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b. // Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) { if (gb.ndim() == 1 && gb.size() == axis_size) {
@@ -357,22 +356,27 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gx); encoder.set_output_array(gx);
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
using DataType = cuda_type_t<CTYPE>; dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, { dispatch_block_dim(
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( auto kernel = cu::layer_norm_vjp<
x.data<DataType>(), DataType,
w.data<DataType>(), has_w_constant(),
g.data<DataType>(), block_dim(),
gx.data<DataType>(), N_READS>;
gw_temp.data<DataType>(), kernel<<<n_rows, block_dim(), 0, stream>>>(
eps_, x.data<DataType>(),
axis_size, w.data<DataType>(),
w_stride); g.data<DataType>(),
}); gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
}); });
}); });
}); });

View File

@@ -144,14 +144,15 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
in.data<DataType>(), out.data<DataType>(), axis_size); auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
}); kernel<<<n_rows, block_dim(), 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
}); });
}); });
} }

View File

@@ -162,11 +162,15 @@ class MatMul {
} }
} }
array workspace( void* workspace_ptr = nullptr;
allocator::malloc(heuristic_.workspaceSize), if (heuristic_.workspaceSize > 0) {
{static_cast<int>(heuristic_.workspaceSize)}, array workspace(
int8); allocator::malloc(heuristic_.workspaceSize),
encoder.add_temporary(workspace); {static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
@@ -183,8 +187,8 @@ class MatMul {
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace.data<void>(), workspace_ptr,
workspace.nbytes(), heuristic_.workspaceSize,
stream)); stream));
}); });
} }
@@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) { encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>; using OutType = cuda_type_t<CTYPE>;
CTYPE step = CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_); static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
@@ -71,10 +72,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
throw std::runtime_error(#func " has no CUDA implementation."); \ throw std::runtime_error(#func " has no CUDA implementation."); \
} }
NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM) NO_GPU(BlockMaskedMM)
NO_GPU(Convolution) NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice) NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate) NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT) NO_GPU(FFT)
@@ -83,7 +82,6 @@ NO_GPU(GatherQMM)
NO_GPU(Hadamard) NO_GPU(Hadamard)
NO_GPU(Load) NO_GPU(Load)
NO_GPU_MULTI(LUF) NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF) NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul) NO_GPU(QuantizedMatmul)
NO_GPU(Scan) NO_GPU(Scan)

View File

@@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(!axes_.empty()); assert(!axes_.empty());
assert(out.size() != in.size()); assert(out.size() != in.size());
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
// Fill out with init value.
if (in.size() == 0) { if (in.size() == 0) {
encoder.launch_kernel([&](cudaStream_t stream) { init_reduce(encoder, in, out, reduce_type_);
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
thrust::fill_n(
cu::thrust_policy(stream),
thrust::device_pointer_cast(out.data<OutType>()),
out.data_size(),
cu::ReduceInit<OP, InType>::value());
});
});
});
return; return;
} }
@@ -51,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// If it is a general reduce then copy the input to a contiguous array and // If it is a general reduce then copy the input to a contiguous array and
// recompute the plan. // recompute the plan.
if (plan.type == GeneralReduce) { //
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
// like we do in Metal. When it comes to broadcasted reduction axes
// some can be ignored eg for min/max.
bool broadcasted = false;
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
broadcasted = in.strides(i) == 0;
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {}); array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s); copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy); encoder.add_temporary(in_copy);
@@ -59,9 +54,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan = get_reduction_plan(in, axes_); plan = get_reduction_plan(in, axes_);
} }
if ((plan.type == ContiguousAllReduce) || if (plan.type == ContiguousAllReduce) {
(plan.type == ContiguousReduce && plan.shape.size() == 1)) { all_reduce(encoder, in, out, reduce_type_);
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
return; return;
} }
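The new check above walks the input dimensions and flags any axis that is kept (not reduced) but has stride 0, i.e. was produced by broadcasting; such inputs are copied to a contiguous array before the plan is recomputed. A standalone sketch of that loop with made-up strides, assuming axes is sorted as in the original:

// Returns true when a non-reduced axis was broadcast (stride 0).
#include <cstdio>
#include <vector>

bool has_broadcast_on_kept_axes(
    const std::vector<int>& strides,
    const std::vector<int>& axes) {
  bool broadcasted = false;
  for (int i = 0, j = 0; i < (int)strides.size() && !broadcasted; i++) {
    if (j < (int)axes.size() && axes[j] == i) {
      j++; // this axis is being reduced; a zero stride here is fine
    } else {
      broadcasted = strides[i] == 0; // kept axis that was broadcast
    }
  }
  return broadcasted;
}

int main() {
  // Shape (4, 8, 16), reducing axis 1; axis 0 broadcast (stride 0) vs. not.
  std::printf("%d\n", has_broadcast_on_kept_axes({0, 16, 1}, {1}));   // 1
  std::printf("%d\n", has_broadcast_on_kept_axes({128, 16, 1}, {1})); // 0
}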

View File

@@ -0,0 +1,152 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[N];
U accs[M];
accs[0] = init;
size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);
size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}
if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
}
__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}
} // namespace cu
void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes()));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int reductions_per_step = threads * N;
size_t steps_needed =
(size + reductions_per_step - 1) / reductions_per_step;
int blocks;
if (steps_needed < 32) {
blocks = 1;
} else if (steps_needed < 128) {
blocks = 32;
} else if (steps_needed < 512) {
blocks = 128;
} else if (steps_needed < 1024) {
blocks = 512;
} else {
blocks = 1024;
}
size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
size_t block_step = steps_per_block * reductions_per_step;
return std::make_tuple(blocks, threads, block_step);
};
int blocks, threads;
size_t block_step;
size_t insize = in.size();
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(dt, [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
insize);
});
});
});
// Set the input for the next step and recalculate the blocks
indata = intermediate.data<void>();
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(intermediate);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(dt, [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata), out.data<U>(), block_step, insize);
});
});
});
}
} // namespace mlx::core
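The launcher above sizes the first pass with get_args (N_READS = 8, up to 512 threads, at most 1024 blocks); when more than one block is needed, each block writes a partial into an intermediate array and a second launch reduces those partials. A host-side sketch of the same arithmetic, with a few example sizes:

#include <algorithm>
#include <cstdio>
#include <tuple>

constexpr int WARP = 32;

std::tuple<int, int, size_t> get_args(size_t size, int N) {
  int threads = (int)std::min<size_t>(512, (size + N - 1) / N);
  threads = ((threads + WARP - 1) / WARP) * WARP;
  size_t per_step = (size_t)threads * N;          // elements per block per step
  size_t steps_needed = (size + per_step - 1) / per_step;
  int blocks;
  if (steps_needed < 32) blocks = 1;
  else if (steps_needed < 128) blocks = 32;
  else if (steps_needed < 512) blocks = 128;
  else if (steps_needed < 1024) blocks = 512;
  else blocks = 1024;
  size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
  return {blocks, threads, steps_per_block * per_step};
}

int main() {
  for (size_t size : {4096ul, 1ul << 20, 1ul << 28}) {
    auto [blocks, threads, block_step] = get_args(size, 8);
    std::printf("size=%zu -> blocks=%d threads=%d block_step=%zu\n",
                size, blocks, threads, block_step);
  }
  // 4096 -> 1 block (single pass); 2^20 -> 128 blocks; 2^28 -> 1024 blocks.
}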

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -36,19 +38,36 @@ struct ColReduceArgs {
const array& in, const array& in,
const ReductionPlan& plan, const ReductionPlan& plan,
const std::vector<int>& axes) { const std::vector<int>& axes) {
using ShapeVector = decltype(plan.shape);
using StridesVector = decltype(plan.strides);
ShapeVector shape_vec;
StridesVector strides_vec;
assert(!plan.shape.empty()); assert(!plan.shape.empty());
reduction_size = plan.shape.back(); reduction_size = plan.shape.back();
reduction_stride = plan.strides.back(); reduction_stride = plan.strides.back();
int64_t stride_back = 1; int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) { while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back(); stride_back *= shape_vec.back();
shape_vec.pop_back(); shape_vec.pop_back();
strides_vec.pop_back(); strides_vec.pop_back();
} }
std::vector<int> indices(shape_vec.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
ShapeVector sorted_shape;
StridesVector sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
}
std::tie(shape_vec, strides_vec) = std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec); collapse_contiguous_dims(sorted_shape, sorted_strides);
shape = const_param(shape_vec); shape = const_param(shape_vec);
strides = const_param(strides_vec); strides = const_param(strides_vec);
ndim = shape_vec.size(); ndim = shape_vec.size();
@@ -64,86 +83,6 @@ struct ColReduceArgs {
} }
}; };
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template < template <
typename T, typename T,
typename U, typename U,
@@ -152,67 +91,94 @@ template <
int BM, int BM,
int BN, int BN,
int N_READS = 4> int N_READS = 4>
__global__ void col_reduce_looped( __global__ void
const T* in, col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int n_warps = BN / N_READS; constexpr int threads_per_row = BN / N_READS;
int out_idx = grid.block_rank() / grid.dim_blocks().x; // Compute the indices for the tile
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
short thread_y = block.thread_rank() / threads_per_row;
// Move the input pointer
in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
tile_x * BN;
// Initialize the running totals
Op op; Op op;
U totals[N_READS]; U totals[N_READS];
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
for (; r < args.non_col_reductions * args.reduction_size; r += BM) { size_t total = args.non_col_reductions * args.reduction_size;
U vals[N_READS]; if (tile_x * BN + BN <= args.reduction_stride) {
cub::LoadDirectBlocked( if (args.reduction_stride % N_READS == 0) {
column, for (size_t r = thread_y; r < total; r += BM) {
make_cast_iterator<U>(in + loop.location() + in_offset), T vals[N_READS];
vals, cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
args.reduction_stride - in_offset, for (int i = 0; i < N_READS; i++) {
ReduceInit<Op, T>::value()); totals[i] = op(totals[i], __cast<U, T>(vals[i]));
for (int i = 0; i < N_READS; i++) { }
totals[i] = op(vals[i], totals[i]); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
in + loop.location(),
vals,
args.reduction_stride - tile_x * BN,
__cast<T, U>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
// Do warp reduce for each output. // Do warp reduce for each output.
constexpr int n_outputs = BN / n_warps; constexpr int n_outputs = BN / threads_per_row;
static_assert(BM == 32 && n_outputs == N_READS); static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN]; __shared__ U shared_vals[BM * BN];
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; short s_idx = thread_y * BN + thread_x * N_READS;
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
shared_vals[col + i] = totals[i]; shared_vals[s_idx + i] = totals[i];
} }
block.sync(); block.sync();
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[col + i], op); totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
} }
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + out_idx * args.reduction_stride + out_offset, out + tile_y * args.reduction_stride + tile_x * BN,
totals, totals,
args.reduction_stride - out_offset); args.reduction_stride - tile_x * BN);
} }
} }
@@ -220,14 +186,57 @@ __global__ void col_reduce_looped(
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args) { const cu::ColReduceArgs& args,
auto out_shape = out.shape(); int bn) {
auto out_strides = out.strides(); int gx, gy = 1;
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
out_shape.pop_back(); size_t n_outer_blocks = out.size() / args.reduction_stride;
out_strides.pop_back(); size_t n_blocks = n_outer_blocks * n_inner_blocks;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
} }
return get_2d_grid_dims(out_shape, out_strides); gx = cuda::ceil_div(n_blocks, gy);
return dim3(gx, gy, 1);
}
void col_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
});
});
});
});
} }
void col_reduce( void col_reduce(
@@ -237,42 +246,23 @@ void col_reduce(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
// Current col reduce options
//
// - col_reduce_looped
//
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
// Moreover we need different kernels for short rows and tuning
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes); cu::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) { // Fallback col reduce
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
} }
} // namespace mlx::core } // namespace mlx::core
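In col_reduce_looped above, the flat block rank is split into (tile_x, tile_y): tile_x selects a BN-wide slice of the reduction stride and tile_y selects the output row, and the grid y dimension only grows when the flat block count would overflow a 32-bit grid.x. A small sketch of that index math with toy sizes:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int BN = 32;
  size_t reduction_stride = 100; // output elements per row
  size_t n_rows = 3;             // out.size() / reduction_stride
  size_t n_inner = (reduction_stride + BN - 1) / BN; // 4 tiles per row
  size_t n_blocks = n_rows * n_inner;

  // Grid sizing: grow gy until gx fits in 32 bits (trivially 1 here).
  uint32_t gy = 1;
  while (n_blocks / gy > INT32_MAX) gy *= 2;
  size_t gx = (n_blocks + gy - 1) / gy;
  std::printf("grid = (%zu, %u)\n", gx, gy);

  // Per-block tile coordinates, as computed in the kernel.
  for (size_t rank = 0; rank < n_blocks; rank++) {
    size_t tile_x = rank % n_inner;
    size_t tile_y = rank / n_inner;
    std::printf("block %zu -> row %zu, columns [%zu, %zu)\n",
                rank, tile_y, tile_x * BN,
                std::min<size_t>((tile_x + 1) * BN, reduction_stride));
  }
}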

View File

@@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename U, typename Op>
__global__ void init_reduce(U* out, size_t size) {
auto index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = ReduceInit<Op, T>::value();
}
}
} // namespace cu
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::init_reduce<T, U, OP>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
});
});
}
} // namespace mlx::core
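The init_reduce launch above splits the flat output size into at most 1024 threads per block and enough blocks to cover the rest (the sketch below ignores the second grid dimension that get_2d_grid_dims may produce):

#include <cstdio>

int main() {
  for (unsigned n : {100u, 1024u, 5000u}) {
    unsigned block = n < 1024 ? n : 1024;
    unsigned grid = (n + 1023) / 1024;
    std::printf("size=%u -> grid=%u block=%u (covers %u threads)\n",
                n, grid, block, grid * block);
  }
}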

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <type_traits>
#include "mlx/backend/common/reduce.h" #include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
@@ -9,51 +11,41 @@
namespace mlx::core { namespace mlx::core {
// Dispatch dynamic ndim to constexpr. template <typename F>
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. void dispatch_reduce_ndim(int ndim, F&& f) {
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ if (ndim == 1) {
if (ndim == 1) { \ f(std::integral_constant<int, 1>{});
constexpr uint32_t NDIM = 1; \ } else if (ndim == 2) {
__VA_ARGS__; \ f(std::integral_constant<int, 2>{});
} else if (ndim == 2) { \ } else {
constexpr uint32_t NDIM = 2; \ f(std::integral_constant<int, 5>{});
__VA_ARGS__; \
} else { \
constexpr uint32_t NDIM = 5; \
__VA_ARGS__; \
} }
}
// Dispatch reduce ops to constexpr. template <typename F>
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \ void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
if (REDUCE == Reduce::ReduceType::And) { \ if (reduce_type == Reduce::ReduceType::And) {
using OP = cu::And; \ f(type_identity<cu::And>{});
__VA_ARGS__; \ } else if (reduce_type == Reduce::ReduceType::Or) {
} else if (REDUCE == Reduce::ReduceType::Or) { \ f(type_identity<cu::Or>{});
using OP = cu::Or; \ } else if (reduce_type == Reduce::ReduceType::Sum) {
__VA_ARGS__; \ f(type_identity<cu::Sum>{});
} else if (REDUCE == Reduce::ReduceType::Sum) { \ } else if (reduce_type == Reduce::ReduceType::Prod) {
using OP = cu::Sum; \ f(type_identity<cu::Prod>{});
__VA_ARGS__; \ } else if (reduce_type == Reduce::ReduceType::Max) {
} else if (REDUCE == Reduce::ReduceType::Prod) { \ f(type_identity<cu::Max>{});
using OP = cu::Prod; \ } else if (reduce_type == Reduce::ReduceType::Min) {
__VA_ARGS__; \ f(type_identity<cu::Min>{});
} else if (REDUCE == Reduce::ReduceType::Max) { \ } else {
using OP = cu::Max; \ throw std::invalid_argument("Unknown reduce type.");
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Min) { \
using OP = cu::Min; \
__VA_ARGS__; \
} else { \
throw std::invalid_argument("Unknown reduce type."); \
} }
}
void segmented_reduce( void all_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type);
const std::vector<int>& axes,
const ReductionPlan& plan);
void row_reduce( void row_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -71,4 +63,10 @@ void col_reduce(
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan); const ReductionPlan& plan);
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type);
} // namespace mlx::core } // namespace mlx::core
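The header above replaces the reduce-op and reduce-ndim macros with dispatchers that nest cleanly at the call sites. A self-contained sketch of how the nesting composes, using a bool selector and a local type_identity wrapper for simplicity, and assuming MLX_GET_TYPE does roughly what decltype(tag)::type does below:

#include <cstdio>
#include <type_traits>

template <typename T> struct type_identity { using type = T; };

struct Sum { template <typename T> T operator()(T a, T b) const { return a + b; } };
struct Max { template <typename T> T operator()(T a, T b) const { return a > b ? a : b; } };

template <typename F>
void dispatch_reduce_ops(bool use_max, F&& f) {
  if (use_max) f(type_identity<Max>{});
  else f(type_identity<Sum>{});
}

template <typename F>
void dispatch_reduce_ndim(int ndim, F&& f) {
  if (ndim == 1) f(std::integral_constant<int, 1>{});
  else if (ndim == 2) f(std::integral_constant<int, 2>{});
  else f(std::integral_constant<int, 5>{});
}

int main() {
  dispatch_reduce_ops(true, [&](auto op_tag) {
    dispatch_reduce_ndim(3, [&](auto ndim) {
      using Op = typename decltype(op_tag)::type; // what MLX_GET_TYPE is assumed to do
      std::printf("Op(2, 5) = %d, NDIM = %d\n", Op{}(2, 5), ndim());
    });
  });
}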

View File

@@ -3,48 +3,89 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
// Reduce ops. // Reduce ops.
struct And { struct And {
__device__ bool operator()(bool a, bool b) { __device__ __forceinline__ bool operator()(bool a, bool b) {
return a && b; return a && b;
} }
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, And>(x, y);
}
}; };
struct Or { struct Or {
__device__ bool operator()(bool a, bool b) { __device__ __forceinline__ bool operator()(bool a, bool b) {
return a || b; return a || b;
} }
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, Or>(x, y);
}
}; };
struct Sum { struct Sum {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a + b; return a + b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Sum>(x, y);
}
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomicAdd(x, y);
}
__device__ void atomic_update(int* x, int y) {
atomicAdd(x, y);
}
__device__ void atomic_update(float* x, float y) {
atomicAdd(x, y);
}
}; };
struct Prod { struct Prod {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a * b; return a * b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Prod>(x, y);
}
}; };
struct Min { struct Min {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a < b ? a : b; return a < b ? a : b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Min>(x, y);
}
}; };
struct Max { struct Max {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a > b ? a : b; return a > b ? a : b;
} }
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Max>(x, y);
}
}; };
// Traits to get the result type of reduce op. // Traits to get the result type of reduce op.
@@ -120,7 +161,7 @@ template <typename T>
struct ReduceInit<Prod, T> { struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() { static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 1}; return T{1, 0};
} else { } else {
return typename ReduceResult<Prod, T>::type{1}; return typename ReduceResult<Prod, T>::type{1};
} }
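The ReduceInit change above swaps the complex product identity from 1 + 1i to 1 + 0i, since multiplying by 1 + 1i rotates and scales the accumulator. A quick check with std::complex:

#include <complex>
#include <cstdio>

int main() {
  std::complex<float> x{3.f, -2.f};
  auto good = x * std::complex<float>{1.f, 0.f}; // unchanged: (3, -2)
  auto bad = x * std::complex<float>{1.f, 1.f};  // wrong: (5, 1)
  std::printf("good = (%g, %g), bad = (%g, %g)\n",
              good.real(), good.imag(), bad.real(), bad.imag());
}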

View File

@@ -0,0 +1,158 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <numeric>
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <size_t N>
struct uint_by_size;
template <>
struct uint_by_size<2> {
using type = uint16_t;
};
template <>
struct uint_by_size<4> {
using type = uint32_t;
};
template <>
struct uint_by_size<8> {
using type = unsigned long long int;
};
template <typename T, typename Op>
__device__ void atomic_reduce(T* x, T y) {
if constexpr (sizeof(T) == 1) {
using U = uint16_t;
U* x_int = (U*)((char*)x - ((size_t)x % 2));
int shift = ((char*)x - (char*)x_int) * 8;
int mask = 0xff << shift;
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
new_val = (old_val & ~mask) | (result << shift);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
} else {
using U = typename uint_by_size<sizeof(T)>::type;
U* x_int = (U*)(x);
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(*((T*)&old_val), y);
new_val = *((U*)&result);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
}
}
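A host-side analogue of the CAS loop above, shown as a sketch with std::atomic, makes the retry structure easier to see: read the current value, combine it with y, and retry the compare-exchange until no other thread changed the word in between.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

template <typename T, typename Op>
void atomic_reduce_host(std::atomic<T>& x, T y) {
  T old_val = x.load();
  T new_val;
  do {
    new_val = Op{}(old_val, y);
    // On failure, old_val is reloaded with the current value and we retry.
  } while (!x.compare_exchange_weak(old_val, new_val));
}

struct MaxOp {
  float operator()(float a, float b) const { return a > b ? a : b; }
};

int main() {
  std::atomic<float> acc{0.0f};
  std::vector<std::thread> ts;
  for (int t = 0; t < 4; t++) {
    ts.emplace_back([&, t] {
      for (int i = 0; i < 1000; i++) {
        atomic_reduce_host<float, MaxOp>(acc, float(t * 1000 + i));
      }
    });
  }
  for (auto& th : ts) th.join();
  std::printf("max = %g\n", acc.load()); // 3999
}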
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
// First reduce in the current warp
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
// Reduce across warps
if (warp.meta_group_size() > 1) {
if (warp.thread_rank() == 0) {
for (int i = 0; i < N; i++) {
smem[warp.meta_group_rank() * N + i] = vals[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < N; i++) {
vals[i] = smem[warp.thread_rank() * N + i];
}
} else {
for (int i = 0; i < N; i++) {
vals[i] = init;
}
}
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
}
}
} // namespace cu
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
if (out.ndim() < in.ndim()) {
throw std::runtime_error(
"Reduction without keepdims only supported for row-contiguous inputs");
}
// Calculate the transpositions applied to in in order to apply them to out.
std::vector<int> axis_order(in.ndim());
std::iota(axis_order.begin(), axis_order.end(), 0);
std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
return in.strides(left) > in.strides(right);
});
// Transpose the shape and calculate the strides
Shape out_shape(in.ndim());
Strides out_strides(in.ndim(), 1);
for (int i = 0; i < in.ndim(); i++) {
out_shape[i] = out.shape(axis_order[i]);
}
for (int i = in.ndim() - 2; i >= 0; i--) {
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
}
// Reverse the axis order to get the final strides
Strides final_strides(in.ndim());
for (int i = 0; i < in.ndim(); i++) {
final_strides[axis_order[i]] = out_strides[i];
}
// Calculate the resulting contiguity and do the memory allocation
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
final_strides,
fl,
allocator::free);
}
} // namespace mlx::core
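allocate_same_layout above orders the axes by decreasing input stride, builds row-major strides in that permuted order, and scatters them back to the original axis positions, so the output is written in the same memory order the input is read. A small worked example of just the stride computation (the contiguity flags and the actual allocation are omitted):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Input of shape (4, 6, 8) stored with axis 0 innermost: strides (1, 4, 24).
  std::vector<long> in_strides = {1, 4, 24};
  std::vector<int> out_shape = {4, 6, 1}; // axis 2 reduced, keepdims

  std::vector<int> order(in_strides.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int l, int r) {
    return in_strides[l] > in_strides[r];
  }); // order = {2, 1, 0}

  std::vector<int> t_shape(order.size());
  std::vector<long> t_strides(order.size(), 1);
  for (size_t i = 0; i < order.size(); i++) t_shape[i] = out_shape[order[i]];
  for (int i = (int)order.size() - 2; i >= 0; i--)
    t_strides[i] = t_shape[i + 1] * t_strides[i + 1];

  std::vector<long> final_strides(order.size());
  for (size_t i = 0; i < order.size(); i++) final_strides[order[i]] = t_strides[i];

  // Prints 1 4 24: axis 0 stays fastest, matching the input's layout,
  // instead of the default row-major strides (6, 1, 1).
  for (auto s : final_strides) std::printf("%ld ", s);
  std::printf("\n");
}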

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -55,84 +57,108 @@ struct RowReduceArgs {
non_row_reductions *= reduce_shape[i]; non_row_reductions *= reduce_shape[i];
} }
} }
// Convert shape and strides as if in was contiguous
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
auto shape_vec = in.shape();
auto strides_vec = in.strides();
std::tie(shape_vec, strides_vec) =
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
std::vector<int> indices(shape_vec.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
decltype(shape_vec) sorted_shape;
decltype(strides_vec) sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
}
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(sorted_shape, sorted_strides);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
}
}; };
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4> template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_small( __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / WARP_SIZE; const U init = cu::ReduceInit<ReduceOp, T>::value();
if (out_idx >= out_size) { ReduceOp op;
return;
T vals[M][N];
U accs[M];
for (int i = 0; i < M; i++) {
accs[i] = init;
} }
Op op; const size_t start_row =
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
out += start_row;
U total_val = ReduceInit<Op, T>::value(); if (size % N == 0) {
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
for (size_t n = warp.thread_rank(); n < args.non_row_reductions; in + k * size + r * (block.size() * N),
n += WARP_SIZE) { vals[k]);
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { for (int j = 0; j < N; j++) {
U vals[N_READS]; accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
cub::LoadDirectBlocked( }
r, }
make_cast_iterator<U>(in + loop.location()), }
vals, } else {
args.row_size, for (size_t r = 0; r < full_blocks; r++) {
ReduceInit<Op, T>::value()); for (int k = 0; k < M; k++) {
total_val = op(total_val, cub::ThreadReduce(vals, op)); cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
} }
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
} }
total_val = cg::reduce(warp, total_val, op); if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + final_offset,
vals[k],
size,
__cast<T, U>(init));
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
if (warp.thread_rank() == 0) { __shared__ U shared_accumulators[32 * M];
out[out_idx] = total_val; block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
out[i] = accs[i];
}
}
} }
} }
@@ -141,55 +167,173 @@ template <
typename U, typename U,
typename Op, typename Op,
int NDIM, int NDIM,
int BLOCK_DIM_X, int BLOCK_DIM,
int N_READS = 4> int N_READS = 4>
__global__ void row_reduce_looped( __global__ void row_reduce_looped(
const T* in, T* in,
U* out, U* out,
size_t out_size, size_t out_size,
const __grid_constant__ RowReduceArgs args) { const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; size_t out_idx = grid.block_rank();
if (out_idx >= out_size) {
return;
}
Op op; Op op;
U total_val = ReduceInit<Op, T>::value(); U total[1];
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) { for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); for (size_t r = 0; r < full_blocks; r++) {
r++) { T vals[N_READS];
U vals[N_READS]; cub::LoadDirectBlockedVectorized<T, N_READS>(
cub::LoadDirectBlocked( block.thread_rank(),
r * BLOCK_DIM_X + block.thread_index().x, in + loop.location() + r * BLOCK_DIM * N_READS,
make_cast_iterator<U>(in + loop.location()), vals);
vals, for (int i = 0; i < N_READS; i++) {
args.row_size, total[0] = op(total[0], __cast<U, T>(vals[i]));
ReduceInit<Op, T>::value()); }
total_val = op(total_val, cub::ThreadReduce(vals, op));
} }
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
__cast<T, U>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data()); loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT; __shared__ U shared_accumulators[32];
__shared__ typename BlockReduceT::TempStorage temp; block_reduce(block, warp, total, shared_accumulators, op, init);
total_val = BlockReduceT(temp).Reduce(total_val, op);
if (block.thread_rank() == 0) { if (block.thread_rank() == 0) {
out[out_idx] = total_val; out[out_idx] = total[0];
} }
} }
} // namespace cu } // namespace cu
void row_reduce_simple(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
if (grid.x >= 1024) {
grid.x = (grid.x + 1) / 2;
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), plan.shape.back());
});
});
});
}
void row_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim(),
threads_constant(),
N_READS>;
block.x = threads_constant();
});
});
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), args);
});
});
});
}
void row_reduce( void row_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -197,54 +341,35 @@ void row_reduce(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
// Current row reduction options
//
// - row_reduce_simple
//
// That means that we are simply reducing across the fastest moving axis.
// We are reducing 1 or 2 rows per threadblock depending on the size of
// output.
//
// - row_reduce_looped
//
// It is a general row reduction. We are computing 1 output per
// threadblock. We read the fastest moving axis vectorized and loop over
// the rest of the axes.
//
// Notes: We opt to read as much in order as possible and leave
// transpositions as they are (contrary to our Metal backend).
// Simple row reduce means that we have 1 axis that we are reducing over and
// it has stride 1.
if (plan.shape.size() == 1) {
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
return;
}
// Make the args struct to help route to the best kernel
cu::RowReduceArgs args(in, plan, axes); cu::RowReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) { // Fallback row reduce
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr size_t N_READS = 4;
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims, num_blocks;
auto kernel =
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
if (args.row_size <= 64) {
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
(args.non_row_reductions <= 8)) {
block_dims.x = std::min(out_dims.x, 1024u);
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
num_blocks.y = out_dims.y;
} else {
block_dims.x = WARP_SIZE;
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
kernel =
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
}
} else {
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
block_dims.x = BLOCK_DIM_X;
kernel = cu::row_reduce_looped<
InType,
OutType,
OP,
NDIM,
BLOCK_DIM_X,
N_READS>;
});
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), out.size(), args);
});
});
});
});
} }
} // namespace mlx::core } // namespace mlx::core
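The row_reduce_simple launch above sizes the block so that threads cover one row at N_READS = 8 elements per thread (capped at 1024 and rounded to a whole number of warps), and when the grid is large (roughly 1024 or more output rows for a 1-D output) it halves the block count and reduces two rows per block. A host-side sketch of that heuristic:

#include <algorithm>
#include <cstdio>

constexpr int WARP = 32;

int main() {
  const int n_reads = 8;
  size_t row = 4096;       // plan.shape.back()
  size_t out_rows = 2048;  // roughly out.size() for a 1-D output

  size_t reductions = (row + n_reads - 1) / n_reads;
  int threads = (int)std::min<size_t>(1024, reductions);
  threads = ((threads + WARP - 1) / WARP) * WARP;

  size_t grid_x = out_rows;
  int rows_per_block = 1;
  if (grid_x >= 1024) {
    grid_x = (grid_x + 1) / 2; // launch half the blocks
    rows_per_block = 2;        // row_reduce_simple<..., M = 2>
  }
  std::printf("threads=%d blocks=%zu rows_per_block=%d\n",
              threads, grid_x, rows_per_block);
}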

View File

@@ -1,84 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>
namespace mlx::core {
template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}
template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}
struct MultiplyOp {
int factor;
__device__ int operator()(int i) {
return i * factor;
}
};
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
auto in_iter = cu::make_cast_iterator<OutType>(
thrust::device_pointer_cast(in.data<InType>()));
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
auto init = cu::ReduceInit<OP, InType>::value();
if (plan.type == ContiguousAllReduce) {
cub_all_reduce(
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
} else if (plan.type == ContiguousReduce) {
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
cub_segmented_reduce(
encoder,
in_iter,
out_ptr,
out.size(),
offsets,
offsets + 1,
OP(),
init,
stream);
} else {
throw std::runtime_error("Unsupported plan in segmented_reduce.");
}
});
});
});
}
} // namespace mlx::core

View File

@@ -225,19 +225,20 @@ void RMSNorm::eval_gpu(
   encoder.set_input_array(w);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
-      constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
-        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
-            x.data<DataType>(),
-            w.data<DataType>(),
-            out.data<DataType>(),
-            eps_,
-            axis_size,
-            w_stride);
-      });
+    dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
+      constexpr uint32_t N_READS = 4;
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
+                x.data<DataType>(),
+                w.data<DataType>(),
+                out.data<DataType>(),
+                eps_,
+                axis_size,
+                w_stride);
+          });
     });
   });
 }

@@ -303,7 +304,6 @@ void RMSNormVJP::eval_gpu(
       encoder.add_temporary(gw_temp);
     }
   }
-  gw.set_data(allocator::malloc(gw.nbytes()));

   encoder.set_input_array(x);
   encoder.set_input_array(w);

@@ -311,22 +311,28 @@ void RMSNormVJP::eval_gpu(
   encoder.set_output_array(gx);
   encoder.set_output_array(gw_temp);
   encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
-      constexpr int N_READS = 4;
-      MLX_SWITCH_BOOL(has_w, HAS_W, {
-        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-          auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
-          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
-              x.data<DataType>(),
-              w.data<DataType>(),
-              g.data<DataType>(),
-              gx.data<DataType>(),
-              gw_temp.data<DataType>(),
-              eps_,
-              axis_size,
-              w_stride);
-        });
+    dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
+      dispatch_bool(has_w, [&](auto has_w_constant) {
+        constexpr int N_READS = 4;
+        dispatch_block_dim(
+            cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+              using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+              constexpr int N_READS = 4;
+              auto kernel = cu::rms_norm_vjp<
+                  DataType,
+                  has_w_constant(),
+                  block_dim(),
+                  N_READS>;
+              kernel<<<n_rows, block_dim(), 0, stream>>>(
+                  x.data<DataType>(),
+                  w.data<DataType>(),
+                  g.data<DataType>(),
+                  gx.data<DataType>(),
+                  gw_temp.data<DataType>(),
+                  eps_,
+                  axis_size,
+                  w_stride);
+            });
       });
     });
   });
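The pattern in these hunks is worth spelling out: helpers such as dispatch_float_types and dispatch_block_dim pass a tag object to a generic lambda, and because the tag's operator() is a constant expression, a runtime value (a dtype, a block size, a flag) can become a template argument inside the lambda. Below is a trimmed-down, self-contained sketch of the block-size half of that idea; the helper and kernel names are invented for illustration, not MLX's actual API.

#include <cstdio>
#include <type_traits>

// Map a runtime thread count onto a compile-time constant by invoking a
// generic lambda with an integral_constant tag (a sketch, not MLX's helper).
template <typename F>
void dispatch_block_dim_sketch(int threads, F&& f) {
  if (threads <= 128) {
    f(std::integral_constant<int, 128>{});
  } else if (threads <= 256) {
    f(std::integral_constant<int, 256>{});
  } else {
    f(std::integral_constant<int, 1024>{});
  }
}

template <int BLOCK_DIM>
void fake_kernel_launch(int n) {
  std::printf("would launch with BLOCK_DIM=%d for n=%d\n", BLOCK_DIM, n);
}

int main() {
  int axis_size = 300;
  dispatch_block_dim_sketch(axis_size, [&](auto block_dim) {
    // block_dim() is a constant expression, so it can parameterize a template.
    fake_kernel_launch<block_dim()>(axis_size);
  });
}

Each branch still instantiates one specialization per case, as the old macros did, but the call sites are ordinary functions that a debugger and compiler diagnostics can see through.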

View File

@@ -310,12 +310,12 @@ void RoPE::eval_gpu(
   encoder.set_input_array(offset);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
-        MLX_SWITCH_BOOL(forward_, FORWARD, {
+    dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
+      dispatch_bool(traditional_, [&](auto traditional) {
+        dispatch_bool(forward_, [&](auto forward) {
+          using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
           if (single && !with_freqs) {
-            auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope_single<DataType, traditional(), forward()>;
             uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
             auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
             kernel<<<grid, block, 0, stream>>>(

@@ -327,7 +327,8 @@ void RoPE::eval_gpu(
                 mat_size,
                 dims);
           } else if (single) {
-            auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
+            auto kernel =
+                cu::rope_single_freqs<DataType, traditional(), forward()>;
             uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
             auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
             kernel<<<grid, block, 0, stream>>>(

@@ -340,7 +341,7 @@ void RoPE::eval_gpu(
                 dims,
                 inputs[2].strides(0));
           } else if (with_freqs) {
-            auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
             uint3 dims =
                 make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
             dims.z = (dims.z + 3) / 4;

@@ -358,7 +359,7 @@ void RoPE::eval_gpu(
                 dims,
                 inputs[2].strides(0));
           } else {
-            auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope<DataType, traditional(), forward()>;
             uint3 dims =
                 make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
             dims.z = (dims.z + 3) / 4;
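RoPE composes the same trick twice: nesting two boolean dispatches instantiates all four (traditional, forward) kernel variants while the call site stays a single expression. A small sketch of that nesting, assuming a dispatch_bool-style helper like the one this diff introduces (all names here are invented):

#include <cstdio>
#include <type_traits>

// Sketch: map a runtime bool onto std::true_type / std::false_type so it can
// become a template argument inside the generic lambda.
template <typename F>
void dispatch_bool_sketch(bool b, F&& f) {
  if (b) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

template <bool TRADITIONAL, bool FORWARD>
void rope_kernel_stub() {
  std::printf("rope<traditional=%d, forward=%d>\n", TRADITIONAL, FORWARD);
}

int main() {
  bool traditional = false, forward = true;
  // Nesting the dispatcher selects one of four stub specializations at runtime.
  dispatch_bool_sketch(traditional, [&](auto t) {
    dispatch_bool_sketch(forward, [&](auto f) { rope_kernel_stub<t(), f()>(); });
  });
}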

View File

@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
       make_cast_iterator<AccT>(in),
       vals,
       axis_size,
-      Limits<AccT>::finite_min());
+      Limits<AccT>::min());
   prevmax = maxval;
   maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
   // Online normalizer calculation for softmax:

@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
     block.sync();
     maxval = warp.thread_rank() < warp.meta_group_size()
         ? local_max[warp.thread_rank()]
-        : Limits<AccT>::finite_min();
+        : Limits<AccT>::min();
     maxval = cg::reduce(warp, maxval, max_op);
     normalizer = normalizer * softmax_exp(prevmax - maxval);
     if (warp.thread_rank() == 0) {

@@ -142,17 +142,18 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
-      constexpr int N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
-        if (precise) {
-          kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
-        }
-        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
-            in.data<DataType>(), out.data<DataType>(), axis_size);
-      });
+    dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
+      constexpr int N_READS = 4;
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
+            if (precise) {
+              kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
+            }
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
+                in.data<DataType>(), out.data<DataType>(), axis_size);
+          });
     });
   });
 }
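The switch from `Limits<AccT>::finite_min()` to `Limits<AccT>::min()` ties into the online-normalizer comment above: the running maximum starts at the lowest representable value (negative infinity for floats), and the running sum of exponentials is rescaled whenever that maximum grows. A scalar CPU sketch of the same recurrence:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 3.0f, -2.0f, 0.5f};
  float maxval = -INFINITY;  // analogous to Limits<AccT>::min()
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum when the maximum changes, then add the new term.
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  for (float v : x) {
    std::printf("softmax = %f\n", std::exp(v - maxval) / normalizer);
  }
}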

View File

@@ -76,17 +76,21 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
       temp.data<void>(), size, args...));
 }

+struct OffsetTransform {
+  int nsort;
+
+  int __device__ operator()(int i) {
+    return i * nsort;
+  }
+};
+
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
   if (axis < 0) {
     axis += in.ndim();
   }
   int nsort = in.shape(axis);
-  int nsegments = in.data_size() / nsort;
   int last_dim = in.ndim() - 1;

   // If we are not sorting the innermost dimension of a contiguous array,

@@ -100,16 +104,22 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
     encoder.add_temporary(out);
   } else {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(
+        allocator::malloc(in.data_size() * out.itemsize()),
+        in.data_size(),
+        in.strides(),
+        in.flags());
   }
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      using CTYPE = MLX_GET_TYPE(type_tag);
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
         using Type = cuda_type_t<CTYPE>;
         auto offsets = thrust::make_transform_iterator(
-            thrust::make_counting_iterator(0),
-            [nsort] __device__(int i) { return i * nsort; });
+            thrust::make_counting_iterator(0), OffsetTransform{nsort});
         if (argsort) {
           // Indices in the sorted dimension.
           array indices(

@@ -134,7 +144,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
             indices.data<uint32_t>(),
             out.data<uint32_t>(),
             in.data_size(),
-            nsegments,
+            in.data_size() / nsort,
             offsets,
             offsets + 1,
             stream);

@@ -144,7 +154,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
             in.data<Type>(),
             out.data<Type>(),
             in.data_size(),
-            nsegments,
+            in.data_size() / nsort,
             offsets,
             offsets + 1,
             stream);

@@ -177,4 +187,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   gpu_sort(stream(), inputs[0], out, axis_, false);
 }

+void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ArgPartition::eval_gpu");
+  gpu_sort(stream(), inputs[0], out, axis_, true);
+}
+
+void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Partition::eval_gpu");
+  gpu_sort(stream(), inputs[0], out, axis_, false);
+}
+
 } // namespace mlx::core
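For reference, the named OffsetTransform functor added above encodes the same mapping the removed extended __device__ lambda did: segment i of a row-contiguous [nsegments, nsort] buffer spans [i * nsort, (i + 1) * nsort), so one iterator can serve as both the begin-offsets argument (`offsets`) and the end-offsets argument (`offsets + 1`). A host-side sketch of just that mapping, with an invented name:

#include <cstdio>

struct OffsetTransformSketch {
  int nsort;
  int operator()(int i) const {
    return i * nsort;
  }
};

int main() {
  OffsetTransformSketch offsets{4};  // nsort == 4
  for (int seg = 0; seg < 3; ++seg) {
    std::printf("segment %d covers [%d, %d)\n", seg, offsets(seg), offsets(seg + 1));
  }
}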

View File

@@ -92,58 +92,63 @@ void ternary_op_gpu_inplace(
   encoder.set_input_array(c);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
-      using DType = cuda_type_t<CTYPE>;
+    dispatch_all_types(out.dtype(), [&](auto type_tag) {
+      using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto topt = get_ternary_op_type(a, b, c);
       if (topt == TernaryOpType::General) {
-        auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-        auto& a_strides = strides[0];
-        auto& b_strides = strides[1];
-        auto& c_strides = strides[2];
-        bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-            c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-        MLX_SWITCH_BOOL(large, LARGE, {
-          using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
-          int ndim = shape.size();
-          if (ndim <= 3) {
-            MLX_SWITCH_1_2_3(ndim, NDIM, {
-              auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
-              auto [num_blocks, block_dims] =
-                  get_launch_args(kernel, out, large);
-              kernel<<<num_blocks, block_dims, 0, stream>>>(
-                  a.data<bool>(),
-                  b.data<DType>(),
-                  c.data<DType>(),
-                  out.data<DType>(),
-                  out.size(),
-                  const_param<NDIM>(shape),
-                  const_param<NDIM>(a_strides),
-                  const_param<NDIM>(b_strides),
-                  const_param<NDIM>(c_strides));
-            });
-          } else {
-            auto kernel = cu::ternary_g<Op, DType, IdxT>;
-            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
-            kernel<<<num_blocks, block_dims, 0, stream>>>(
-                a.data<bool>(),
-                b.data<DType>(),
-                c.data<DType>(),
-                out.data<DType>(),
-                out.data_size(),
-                const_param(shape),
-                const_param(a_strides),
-                const_param(b_strides),
-                const_param(c_strides),
-                ndim);
-          }
-        });
+        dispatch_bool(
+            a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+            [&](auto large) {
+              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+              Shape shape;
+              std::vector<Strides> strides;
+              std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
+              auto& a_strides = strides[0];
+              auto& b_strides = strides[1];
+              auto& c_strides = strides[2];
+              int ndim = shape.size();
+              if (ndim <= 3) {
+                dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                  auto kernel =
+                      cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out, large());
+                  kernel<<<num_blocks, block_dims, 0, stream>>>(
+                      a.data<bool>(),
+                      b.data<DType>(),
+                      c.data<DType>(),
+                      out.data<DType>(),
+                      out.size(),
+                      const_param<dims_constant()>(shape),
+                      const_param<dims_constant()>(a_strides),
+                      const_param<dims_constant()>(b_strides),
+                      const_param<dims_constant()>(c_strides));
+                });
+              } else {
+                auto kernel = cu::ternary_g<Op, DType, IdxT>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large());
+                kernel<<<num_blocks, block_dims, 0, stream>>>(
+                    a.data<bool>(),
+                    b.data<DType>(),
+                    c.data<DType>(),
+                    out.data<DType>(),
+                    out.data_size(),
+                    const_param(shape),
+                    const_param(a_strides),
+                    const_param(b_strides),
+                    const_param(c_strides),
+                    ndim);
+              }
+            });
       } else {
-        MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
-          using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
           auto kernel = cu::ternary_v<Op, DType, IdxT>;
           auto [num_blocks, block_dims] = get_launch_args(
-              kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+              kernel, out.data_size(), out.shape(), out.strides(), large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               a.data<bool>(),
               b.data<DType>(),

View File

@@ -20,38 +20,35 @@ namespace cu {
 template <typename Op, typename In, typename Out>
 constexpr bool supports_unary_op() {
   if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
-      std::is_same_v<Op, Sign>) {
+      std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
     return std::is_same_v<In, Out>;
   }
-  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
-      std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
-      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
-      std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
-      std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
-      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
+  if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
+      std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
+      std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
+      std::is_same_v<Op, Sigmoid>) {
     return std::is_same_v<In, Out> && is_floating_v<In>;
   }
-  if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
-      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
-    return std::is_same_v<In, Out> && is_inexact_v<In>;
-  }
   if (std::is_same_v<Op, BitwiseInvert>) {
     return std::is_same_v<In, Out> && std::is_integral_v<In> &&
         !std::is_same_v<In, bool>;
   }
-  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
-      std::is_same_v<Op, Square>) {
+  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
     return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
   }
   if (std::is_same_v<Op, Conjugate>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
   }
-  if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
-      std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
-      std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
-      std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
-    return std::is_same_v<In, Out> &&
-        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
+  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
+      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
+      std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
+      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
+      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
+      std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
+      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
+      std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
+      std::is_same_v<Op, Tanh>) {
+    return std::is_same_v<In, Out> && is_inexact_v<In>;
   }
   if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
     return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;

@@ -79,8 +76,10 @@ void unary_op_gpu_inplace(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
-      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
         if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
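The regrouped `supports_unary_op` above decides which (op, input, output) combinations get compiled at all; everything else is rejected inside `if constexpr` before a kernel is ever instantiated. Below is a self-contained sketch of that gating pattern with invented op and trait names, assuming (as the sketch's trait does) that the op keeps its input type.

#include <stdexcept>
#include <type_traits>

struct SquareOp {
  template <typename T>
  T operator()(T x) const {
    return x * x;
  }
};

// Assumption for the sketch: the op is only supported for arithmetic types
// and never changes the element type.
template <typename Op, typename In, typename Out>
constexpr bool supports_op_sketch() {
  return std::is_same_v<In, Out> && std::is_arithmetic_v<In>;
}

template <typename Op, typename In, typename Out>
Out apply_or_throw(In x) {
  if constexpr (supports_op_sketch<Op, In, Out>()) {
    return Op{}(x);  // only this branch is instantiated for supported combos
  } else {
    throw std::runtime_error("unsupported type combination for this op");
  }
}

int main() {
  return static_cast<int>(apply_or_throw<SquareOp, int, int>(3)) - 9;  // exits 0
}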

View File

@@ -25,22 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) {
 }

 const char* dtype_to_cuda_type(const Dtype& dtype) {
-  if (dtype == float16) {
-    return "__half";
-  }
-  if (dtype == bfloat16) {
-    return "__nv_bfloat16";
-  }
-  if (dtype == complex64) {
-    return "cuComplex";
-  }
-#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
-  if (dtype == DTYPE) {                           \
-    return #CPP_TYPE;                             \
-  }
-  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
-#undef SPECIALIZE_DtypeToString
-  return nullptr;
+  switch (dtype) {
+    case bool_:
+      return "bool";
+    case int8:
+      return "int8_t";
+    case int16:
+      return "int16_t";
+    case int32:
+      return "int32_t";
+    case int64:
+      return "int64_t";
+    case uint8:
+      return "uint8_t";
+    case uint16:
+      return "uint16_t";
+    case uint32:
+      return "uint32_t";
+    case uint64:
+      return "uint64_t";
+    case float16:
+      return "__half";
+    case bfloat16:
+      return "__nv_bfloat16";
+    case float32:
+      return "float";
+    case float64:
+      return "double";
+    case complex64:
+      return "cuComplex";
+    default:
+      return "unknown";
+  }
 }

 } // namespace mlx::core

View File

@@ -80,7 +80,9 @@ void Worker::thread_fn() {
     }
     worker_tasks_.erase(worker_tasks_.begin(), end);
   }
-  for (auto& task : tasks) {
+  // Make sure tasks are cleared before the next wait
+  for (int i = 0; i < tasks.size(); ++i) {
+    auto task = std::move(tasks[i]);
     task();
   }
   worker_event_.wait(batch + 1);
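The indexed loop exists so each task is moved into a local before it runs: its captures (which may be keeping buffers alive) are destroyed as soon as the call returns, instead of surviving until the whole batch vector does. A small standalone illustration of that lifetime difference (a sketch, not the Worker class):

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

int main() {
  std::vector<std::function<void()>> tasks;
  auto payload = std::make_shared<int>(42);
  std::weak_ptr<int> watcher = payload;
  tasks.push_back([payload] { std::printf("payload = %d\n", *payload); });
  payload.reset();  // the task's capture is now the only owner

  for (size_t i = 0; i < tasks.size(); ++i) {
    auto task = std::move(tasks[i]);  // capture dies when `task` goes out of scope
    task();
  }
  // Had we iterated with `for (auto& task : tasks)`, the capture (and whatever
  // it keeps alive) would still exist here.
  std::printf("payload alive after loop: %s\n", watcher.expired() ? "no" : "yes");
}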

View File

@@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
   }
 }

+// Any parent in the divider will continue to refer to `x` but any parent not
+// in the divider will refer to a copy of the operation.
+array split_one(
+    const array& x,
+    ParentsMap& parents_map,
+    const std::unordered_set<uintptr_t>& divider) {
+  array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
+  auto& x_parents = parents_map[x.id()];
+  auto& y_parents = parents_map[y.id()];
+  for (auto it = x_parents.begin(); it != x_parents.end();) {
+    if (divider.find(it->first.id()) != divider.end()) {
+      it->first.inputs()[it->second] = y;
+      y_parents.emplace_back(std::move(*it));
+      it = x_parents.erase(it);
+    } else {
+      it++;
+    }
+  }
+  return std::move(y);
+}
+
 template <typename T, typename... U>
 std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
   using FunType = T (*)(U...);

@@ -669,10 +693,16 @@ void compile_fuse(
     }

     // Arrays with a mix of parents outside the compilable section
-    // are not fusable
+    // are not fusable except for broadcast which we can split to avoid
+    // stopping fusion
     if (!all_parents_in) {
-      // Possible input
-      input_set.insert(a.id());
+      if (a.has_primitive() && is_broadcast(a.primitive())) {
+        array b = split_one(a, parents_map, cache);
+        recurse(b, depth, s, shape);
+      } else {
+        // Possible input
+        input_set.insert(a.id());
+      }
       return;
     }
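split_one's job is easiest to see on a toy graph: parents listed in the divider set are re-pointed at a fresh copy of the node, while everyone else keeps the original, so a broadcast consumed both inside and outside a fused section no longer blocks fusion. Below is a sketch of just that parent-partitioning step, using plain strings in place of MLX arrays (all names invented):

#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// (parent name, input slot) pairs, mirroring the shape of ParentsMap entries.
using Parents = std::vector<std::pair<std::string, int>>;

int main() {
  std::unordered_map<std::string, Parents> parents_map = {
      {"x", {{"fusable_consumer", 0}, {"outside_consumer", 1}}}};
  std::unordered_set<std::string> divider = {"fusable_consumer"};

  Parents& x_parents = parents_map["x"];
  Parents y_parents;  // parents of the copy `y`
  for (auto it = x_parents.begin(); it != x_parents.end();) {
    if (divider.count(it->first)) {
      y_parents.push_back(std::move(*it));  // this parent now consumes the copy
      it = x_parents.erase(it);
    } else {
      ++it;
    }
  }
  std::printf("x keeps %zu parent(s), copy y gets %zu\n",
              x_parents.size(), y_parents.size());
}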

View File

@@ -5,16 +5,38 @@
 namespace mlx::core {

 const char* dtype_to_string(Dtype arg) {
-  if (arg == bool_) {
-    return "bool";
-  }
-#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
-  if (DTYPE == arg) {                             \
-    return #DTYPE;                                \
-  }
-  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
-#undef SPECIALIZE_DtypeToString
-  return "(unknown)";
+  switch (arg) {
+    case bool_:
+      return "bool";
+    case int8:
+      return "int8";
+    case int16:
+      return "int16";
+    case int32:
+      return "int32";
+    case int64:
+      return "int64";
+    case uint8:
+      return "uint8";
+    case uint16:
+      return "uint16";
+    case uint32:
+      return "uint32";
+    case uint64:
+      return "uint64";
+    case float16:
+      return "float16";
+    case bfloat16:
+      return "bfloat16";
+    case float32:
+      return "float32";
+    case float64:
+      return "float64";
+    case complex64:
+      return "complex64";
+    default:
+      return "unknown";
+  }
 }

 } // namespace mlx::core

View File

@@ -1,207 +1,106 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// Copyright © Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in
// https://github.com/pytorch/executorch/blob/main/LICENSE
//
// Forked from
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
#pragma once #pragma once
#include "mlx/dtype.h" #include <sstream>
#include <fmt/format.h> #include "mlx/dtype.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
// Return string representation of dtype. // Return string representation of dtype.
const char* dtype_to_string(Dtype arg); const char* dtype_to_string(Dtype arg);
// Macros that iterate across different subsets of Dtypes. #define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
// case DTYPE: \
// For all of these macros, the final `_` parameter is the name of another macro f(type_identity<TYPE>{}); \
// that takes two parameters: the name of a C type, and the name of the break
// corresponding Dtype enumerator.
//
// Note that these macros should use fully-qualified namespaces (starting with
// `::`) to ensure that they can be called safely in any arbitrary namespace.
#define MLX_FORALL_INT_TYPES(_) \
_(uint8_t, uint8) \
_(uint16_t, uint16) \
_(uint32_t, uint32) \
_(uint64_t, uint64) \
_(int8_t, int8) \
_(int16_t, int16) \
_(int32_t, int32) \
_(int64_t, int64)
#define MLX_FORALL_FLOAT_TYPES(_) \ #define MLX_INTERNAL_DTYPE_SWITCH_INTS() \
_(float16_t, float16) \ MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \
_(float, float32) \ MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \
_(double, float64) \ MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \
_(bfloat16_t, bfloat16) MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)
// Calls the provided macro on every Dtype, providing the C type and the #define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \
// Dtype name to each call. MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \
// MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
// @param _ A macro that takes two parameters: the name of a C type, and the MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \
// name of the corresponding Dtype enumerator. MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)
#define MLX_FORALL_DTYPES(_) \
MLX_FORALL_INT_TYPES(_) \
MLX_FORALL_FLOAT_TYPES(_) \
_(bool, bool_) \
_(complex64_t, complex64)
// Maps Dtypes to C++ types. // This already exists in C++20 but in C++20 we can also just use templated
template <Dtype::Val N> // lambdas which will make this so much nicer.
struct DtypeToCppType;
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
template <> \
struct DtypeToCppType<Dtype::Val::DTYPE> { \
using type = CPP_TYPE; \
};
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
#undef SPECIALIZE_DtypeToCppType
// Maps C++ types to Dtypes.
template <typename T> template <typename T>
struct CppTypeToDtype; struct type_identity {
using type = T;
};
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ #define MLX_GET_TYPE(x) typename decltype(x)::type
template <> \ #define MLX_GET_VALUE(x) decltype(x)::value
struct CppTypeToDtype<CPP_TYPE> \
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) template <typename F>
void dispatch_all_types(Dtype dt, F&& f) {
#undef SPECIALIZE_CppTypeToDtype switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
// Helper macros for switch case macros (see below) MLX_INTERNAL_DTYPE_SWITCH_INTS();
// MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
// These macros are not meant to be used directly. They provide an easy way to MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
// generate a switch statement that can handle subsets of Dtypes supported.
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
case enum_type: { \
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
__VA_ARGS__; \
break; \
} }
}
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ template <typename F>
switch (TYPE) { \ void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
__VA_ARGS__ \ switch (dt) {
default: \ MLX_INTERNAL_DTYPE_SWITCH_INTS();
throw std::invalid_argument(fmt::format( \ default:
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ std::ostringstream msg;
msg << tag << " Only integer types supported but " << dt
<< " was provided";
throw std::invalid_argument(msg.str());
} }
}
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ template <typename F>
MLX_INTERNAL_SWITCH_CASE( \ void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ switch (dt) {
MLX_INTERNAL_SWITCH_CASE( \ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ default:
MLX_INTERNAL_SWITCH_CASE( \ std::ostringstream msg;
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ msg << tag << " Only float types supported but " << dt << " was provided";
MLX_INTERNAL_SWITCH_CASE( \ throw std::invalid_argument(msg.str());
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ }
MLX_INTERNAL_SWITCH_CASE( \ }
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ template <typename F>
MLX_INTERNAL_SWITCH_CASE( \ void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ switch (dt) {
MLX_INTERNAL_SWITCH_CASE( \ MLX_INTERNAL_DTYPE_SWITCH_INTS();
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
MLX_INTERNAL_SWITCH_CASE( \ default:
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ std::ostringstream msg;
MLX_INTERNAL_SWITCH_CASE( \ msg << tag << " Only integer and float types supported but " << dt
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) << " was provided";
throw std::invalid_argument(msg.str());
}
}
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \ template <typename F>
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ MLX_INTERNAL_DTYPE_SWITCH_INTS();
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
MLX_INTERNAL_SWITCH_CASE( \ default:
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) std::ostringstream msg;
msg << tag << " Only real numbers supported but " << dt
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ << " was provided";
MLX_INTERNAL_SWITCH_CASE( \ throw std::invalid_argument(msg.str());
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) }
}
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
// Switch case macros
//
// These macros provide an easy way to generate switch statements that apply a
// common lambda function to subsets of Dtypes supported by MLX.
// The lambda function can type specialize to the ctype associated with the
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
//
// Arguments:
// - ADDITIONAL: Additional Dtype case to add
// - TYPE: The Dtype to handle through the switch statement
// - NAME: A name for this operation which will be used in error messages
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
// - ...: A statement to be applied to each Dtype case
//
// An example usage is:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
// });
//
// Note that these can be nested as well:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
// });
// });
//
// These macros are adapted from Dispatch.h in the ATen library. The primary
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
// used to alias the ctype associated with the Dtype that is being handled.
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
} // namespace mlx::core } // namespace mlx::core
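The new header replaces the MLX_SWITCH_* macro family with ordinary function templates: each dispatcher switches on the dtype and hands the lambda a type_identity tag, which MLX_GET_TYPE unwraps back into a C++ type. A compilable, trimmed-down sketch of the same shape (only three dtypes, and the names are simplified stand-ins rather than MLX's):

#include <cstdio>
#include <stdexcept>

enum class Dt { f32, i32, boolean };

template <typename T>
struct type_identity { using type = T; };

// Unwrap the tag back into a type name, as MLX_GET_TYPE does.
#define GET_TYPE(x) typename decltype(x)::type

template <typename F>
void dispatch_all_types_sketch(Dt dt, F&& f) {
  switch (dt) {
    case Dt::f32:     f(type_identity<float>{}); break;
    case Dt::i32:     f(type_identity<int>{});   break;
    case Dt::boolean: f(type_identity<bool>{});  break;
    default: throw std::invalid_argument("unhandled dtype");
  }
}

int main() {
  dispatch_all_types_sketch(Dt::f32, [&](auto type_tag) {
    using T = GET_TYPE(type_tag);
    std::printf("sizeof(T) = %zu\n", sizeof(T));
  });
}

Nesting two dispatchers recovers the old nested-macro usage for input and output dtypes, and, as the comment in the header notes, once C++20 templated lambdas are available the GET_TYPE macro can go away entirely.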

View File

@@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
 std::ostream& operator<<(std::ostream& os, array a) {
   a.eval();
-  MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
+  dispatch_all_types(a.dtype(), [&](auto type_tag) {
+    print_array<MLX_GET_TYPE(type_tag)>(os, a);
+  });
   return os;
 }

@@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
 }

 iinfo::iinfo(Dtype dtype) : dtype(dtype) {
-  MLX_SWITCH_INT_TYPES_CHECKED(
-      dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
+  dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
+    set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
+  });
 }

 } // namespace mlx::core

View File

@@ -4,7 +4,7 @@
 #define MLX_VERSION_MAJOR 0
 #define MLX_VERSION_MINOR 26
-#define MLX_VERSION_PATCH 1
+#define MLX_VERSION_PATCH 2
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -546,7 +546,7 @@ class GELU(Module):
     See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
     functional equivalents and information regarding error bounds.

     Args:
         approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.

@@ -554,20 +554,19 @@ class GELU(Module):
     def __init__(self, approx="none"):
         super().__init__()
+        self._approx = approx

-        if approx == "none":
-            self._act = gelu
-        elif approx == "precise" or approx == "tanh":
-            self._act = gelu_approx
-        elif approx == "fast":
-            self._act = gelu_fast_approx
-        else:
+        allowed = ["none", "precise", "tanh", "fast"]
+        if approx not in allowed:
             raise ValueError(
-                f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
+                f"The approximation should be in {allowed} but '{approx}' was given"
             )

     def __call__(self, x):
-        return self._act(x)
+        if self._approx == "none":
+            return gelu(x)
+        elif self._approx in ["precise", "tanh"]:
+            return gelu_approx(x)
+        return gelu_fast_approx(x)


 @_make_activation_module(tanh)

View File

@@ -404,7 +404,7 @@ class Module(dict):
                     dst[k] = new_value
                 elif isinstance(current_value, (dict, list)):
                     apply(current_value, new_value)
-                elif strict:
+                elif strict and new_value != {}:
                     raise ValueError(
                         f"Received invalid type: {type(new_value).__name__}."
                     )

@@ -413,14 +413,14 @@ class Module(dict):
                         f'Module does not have sub-module named "{k}".'
                     )
         elif isinstance(modules, list):
-            for i in range(len(dst)):
+            for i in range(len(modules)):
                 current_value = dst[i]
                 new_value = modules[i]
                 if self.is_module(current_value) and self.is_module(new_value):
                     dst[i] = new_value
                 elif isinstance(current_value, (dict, list)):
                     apply(current_value, new_value)
-                elif strict:
+                elif strict and new_value != {}:
                     raise ValueError(
                         f"Received invalid type: {type(new_value).__name__}."
                     )

View File

@@ -0,0 +1,17 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc*
cd wheelhouse
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .

View File

@@ -205,6 +205,8 @@ nb::object to_scalar(mx::array& a) {
       return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
     case mx::complex64:
       return nb::cast(a.item<std::complex<float>>());
+    case mx::float64:
+      return nb::cast(a.item<double>());
     default:
       throw nb::type_error("type cannot be converted to Python scalar.");
   }

View File

@@ -1,43 +1,15 @@
 cuda_skip = {
-    "TestArray.test_api",
-    "TestAutograd.test_update_state",
-    "TestBF16.test_arg_reduction_ops",
-    "TestBF16.test_reduction_ops",
-    "TestBlas.test_complex_gemm",
-    "TestCompile.test_compile_dynamic_dims",
-    "TestEinsum.test_ellipses",
-    "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestMemory.test_memory_info",
-    "TestLayers.test_group_norm",
-    "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",
-    "TestLayers.test_sin_pe",
-    "TestLayers.test_upsample",
-    "TestOps.test_array_equal",
-    "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
-    "TestOps.test_softmax",
-    "TestOps.test_sort",
-    "TestOps.test_tile",
-    "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
-    "TestReduce.test_expand_sums",
-    "TestReduce.test_many_reduction_axes",
-    "TestUpsample.test_torch_upsample",
-    # DivMod NYI
-    "TestOps.test_divmod",
-    "TestEval.test_multi_output_eval_during_transform",
-    # Partition NYI
-    "TestAutograd.test_topk_grad",
-    "TestOps.test_argpartition",
-    "TestOps.test_partition",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
     # Scan NYI
+    "TestArray.test_api",
     "TestAutograd.test_cumprod_grad",
     "TestOps.test_scans",
     "TestOps.test_logcumsumexp",

View File

@@ -1,6 +1,10 @@
 # Copyright © 2023 Apple Inc.

 import os
+
+# Use regular fp32 precision for tests
+os.environ["MLX_ENABLE_TF32"] = "0"
+
 import platform
 import unittest
 from typing import Any, Callable, List, Tuple, Union

View File

@@ -2,8 +2,10 @@
 import gc
 import io
+import math
 import unittest
 from functools import partial
+from io import StringIO

 import mlx.core as mx
 import mlx_tests
self.assertEqual(mem_pre, mem_post) self.assertEqual(mem_pre, mem_post)
def test_double_constant(self):
with mx.stream(mx.cpu):
x = mx.array(1.0, dtype=mx.float64)
def fun(x):
return (x + math.pi) * 2.0
y = fun(x).item()
y_compiled = mx.compile(fun)(x).item()
self.assertEqual(y, y_compiled)
def test_shared_broadcast(self):
def fun(x, y, z):
yy = mx.broadcast_to(y, z.shape)
return (x + yy * z), yy.sum()
a = mx.random.normal((10, 10))
b = mx.array(0.1)
c = mx.random.normal((10, 10))
mx.eval(a, b, c)
fc = mx.compile(fun)
d = fc(a, b, c)
s = StringIO()
mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
s.seek(0)
s = s.read()
self.assertTrue("CompiledBroadcastMultiplyAdd" in s)
d_hat = fun(a, b, c)
self.assertTrue(mx.allclose(d[0], d_hat[0]))
self.assertTrue(mx.allclose(d[1], d_hat[1]))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()

View File

@@ -259,6 +259,21 @@ class TestBase(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             m = m.update_modules({"list": ["hi"]})

+        # Allow updating a strict subset
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+        m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
+        self.assertEqual(m.layers[1].weight.shape, (4, 3))
+
+        # Using leaf_modules in the update should always work
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
+                self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
+
+        m = MyModel()
+        m.update_modules(m.leaf_modules())
+

 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):

View File

@@ -174,20 +174,26 @@ if __name__ == "__main__":
     )
     package_dir = {"": "python"}
     package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
+    install_requires = []
+    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
+    if build_cuda:
+        install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]

     setup(
-        name="mlx",
+        name="mlx-cuda" if build_cuda else "mlx",
         version=get_version(),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",
         long_description=long_description,
         long_description_content_type="text/markdown",
+        license="MIT",
         url="https://github.com/ml-explore/mlx",
         packages=packages,
         package_dir=package_dir,
         package_data=package_data,
         include_package_data=True,
+        install_requires=install_requires,
         extras_require={
             "dev": [
                 "nanobind==2.4.0",