Compare commits


10 Commits

Author SHA1 Message Date
Awni Hannun
fb4e8b896b patch bump (#2343) 2025-07-08 14:26:07 -07:00
Cheng
2ca533b279 Fix compilation with CUDA 11 (#2331) 2025-07-07 20:00:43 -07:00
Angelos Katharopoulos
4a9b29a875 MoE backward improvements (#2335) 2025-07-07 17:59:53 -07:00
Awni Hannun
a4fcc893cd auto build linux release (#2341) 2025-07-07 09:29:23 -07:00
Cheng
9d10239af7 [CUDA] Do vectorized store/load in binary ops (#2330) 2025-07-07 08:44:14 -07:00
Cheng
19facd4b20 Build with all cpu cores by default (#2336) 2025-07-07 06:06:45 -07:00
Angelos Katharopoulos
f5299f72cd Fix layernorm race condition (#2340) 2025-07-07 06:06:01 -07:00
Cheng
0e0d9ac522 [CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329) 2025-07-05 08:33:29 -07:00
Awni Hannun
8917022deb fix graphs for older cuda (#2328) 2025-07-02 19:37:58 -07:00
Awni Hannun
ec0d5db67b [CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
2025-07-02 15:59:13 -07:00
47 changed files with 1736 additions and 1316 deletions
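Note: the sketch below is not part of the MLX diff; it is a minimal, hypothetical illustration of the generic CUDA-graph flow that the "Switch to CUDA graphs" commit above builds on — record kernel launches with stream capture, instantiate the captured graph once, then replay it with single graph launches to cut per-kernel launch overhead.

// Hypothetical standalone sketch (kernel and sizes invented for illustration).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] += 1.0f;
  }
}

int main() {
  const int n = 1 << 20;
  float* x;
  cudaMalloc(&x, n * sizeof(float));
  cudaMemset(x, 0, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Record the kernel launches into a graph instead of executing them.
  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  add_one<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  add_one<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  cudaStreamEndCapture(stream, &graph);

  // Instantiate once, then replay the whole graph with a single launch call.
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, NULL, NULL, 0);
  for (int i = 0; i < 10; ++i) {
    cudaGraphLaunch(exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  printf("done\n");
  return 0;
}

MLX's CommandEncoder (in the diffs below) builds its graph explicitly node by node rather than capturing a whole stream, but the instantiate/launch/update lifecycle is the same.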

View File

@@ -41,7 +41,7 @@ jobs:
             pip install --upgrade pip
             pip install --upgrade cmake
             pip install -r docs/requirements.txt
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+            pip install . -v
       - when:
           condition:
             not: << parameters.upload-docs >>
@@ -97,10 +97,8 @@ jobs:
           name: Install Python package
           command: |
             CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py build_ext --inplace
             CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py develop
       - run:
           name: Generate package stubs
@@ -157,8 +155,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
             pip install -e . -v
       - run:
           name: Generate package stubs
@@ -208,8 +205,7 @@ jobs:
          name: Run Python tests with JIT
          command: |
            source env/bin/activate
-           CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-           CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+           CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
            pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
            METAL_DEBUG_ERROR_MODE=0 \
@@ -228,8 +224,7 @@ jobs:
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            python -m venv env
            source env/bin/activate
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-           CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+           CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
            pip install -e ".[dev]"
      - run:
          name: Run Python tests
@@ -278,7 +273,6 @@ jobs:
          command: |
            source env/bin/activate
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
-           CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            pip install . -v
      - run:
          name: Generate package stubs
@@ -290,9 +284,7 @@ jobs:
          name: Build Python package
          command: |
            source env/bin/activate
-           << parameters.build_env >> \
-           CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-           python -m build -w
+           << parameters.build_env >> python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
@@ -340,14 +332,10 @@ jobs:
            pip install patchelf
            pip install build
            pip install twine
-           << parameters.extra_env >> \
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-           pip install . -v
+           << parameters.extra_env >> pip install . -v
            pip install typing_extensions
            python setup.py generate_stubs
-           << parameters.extra_env >> \
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-           python -m build --wheel
+           << parameters.extra_env >> python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - run:
@@ -383,12 +371,10 @@ jobs:
            pip install build
            pip install twine
            << parameters.extra_env >> \
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
            pip install ".[dev]" -v
            python setup.py generate_stubs
            << parameters.extra_env >> \
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
            python -m build --wheel
            bash python/scripts/repair_cuda.sh
@@ -506,6 +492,16 @@ workflows:
              branches:
                ignore: /.*/
              upload-docs: true
+      - build_linux_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              extra_env: ["PYPI_RELEASE=1"]
   prb:
     when:

View File

@@ -88,20 +88,20 @@ Then simply build and install MLX using pip:

 .. code-block:: shell

-   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
+   pip install .

 For developing, install the package with development dependencies, and use an
 editable install:

 .. code-block:: shell

-   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
+   pip install -e ".[dev]"

 Once the development dependencies are installed, you can build faster with:

 .. code-block:: shell

-   CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
+   python setup.py build_ext --inplace

 Run the tests with:

@@ -262,7 +262,7 @@ When building either the Python or C++ APIs make sure to pass the cmake flag

 .. code-block:: shell

-   CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+   CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"

 To build the C++ package run:

View File

@@ -12,16 +12,11 @@ namespace mlx::core {
 inline std::tuple<Shape, Strides, Strides> collapse_batches(
     const array& a,
     const array& b) {
-  // Get and check the shape for the batched dims
-  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
-  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
-  if (A_bshape != B_bshape) {
-    std::ostringstream msg;
-    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
-        << a.shape() << ", B " << b.shape() << ".";
-    throw std::runtime_error(msg.str());
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}};
   }

+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
   Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
   Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
@@ -42,17 +37,11 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
 inline std::tuple<Shape, Strides, Strides, Strides>
 collapse_batches(const array& a, const array& b, const array& c) {
-  // Get and check the shape for the batched dims
-  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
-  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
-  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
-  if (A_bshape != B_bshape || A_bshape != C_bshape) {
-    std::ostringstream msg;
-    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
-        << a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
-    throw std::runtime_error(msg.str());
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}, {0}};
   }

+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
   Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
   Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
   Strides C_bstride{c.strides().begin(), c.strides().end() - 2};

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh" #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
@@ -151,30 +152,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; constexpr uint32_t N_READS = 4;
constexpr uint32_t N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel =
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
auto kernel = if (reduce_type_ == ArgReduce::ArgMin) {
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>; kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) { }
kernel = cu:: encoder.add_kernel_node(
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>; kernel,
} num_blocks,
kernel<<<num_blocks, block_dim(), 0, stream>>>( block_dim(),
in.data<T>(), in.data<T>(),
out.data<uint32_t>(), out.data<uint32_t>(),
out.size(), out.size(),
const_param(shape), const_param(shape),
const_param(in_strides), const_param(in_strides),
const_param(out_strides), const_param(out_strides),
ndim, ndim,
axis_stride, axis_stride,
axis_size); axis_size);
});
}); });
}); });
} }

View File

@@ -17,35 +17,106 @@ namespace cu {
 namespace cg = cooperative_groups;

-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[0], b[0]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[0], b[0]);
+    }
+  } else {
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a[0], b[0]);
+    }
+    store_vector<N_READS>(out, index, out_vec);
   }
 }

-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[0], b[index]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[0], b[offset]);
+    }
+  } else {
+    auto b_vec = load_vector<N_READS>(b, index);
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
+    }
+    store_vector<N_READS>(out, index, out_vec);
   }
 }

-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[index], b[0]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[offset], b[0]);
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
+    }
+    store_vector<N_READS>(out, index, out_vec);
   }
 }

-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[index], b[index]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[offset], b[offset]);
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+    auto b_vec = load_vector<N_READS>(b, index);
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
+    }
+    store_vector<N_READS>(out, index, out_vec);
   }
 }
@@ -139,90 +210,99 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
-        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
-        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
-          using InType = cuda_type_t<CTYPE_IN>;
-          using OutType = cuda_type_t<CTYPE_OUT>;
-          auto bopt = get_binary_op_type(a, b);
-          if (bopt == BinaryOpType::General) {
-            dispatch_bool(
-                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-                    out.data_size() > INT32_MAX,
-                [&](auto large) {
-                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-                  Shape shape;
-                  std::vector<Strides> strides;
-                  std::tie(shape, strides) =
-                      collapse_contiguous_dims(a, b, out);
-                  auto& a_strides = strides[0];
-                  auto& b_strides = strides[1];
-                  int ndim = shape.size();
-                  if (ndim <= 3) {
-                    dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                      auto kernel = cu::binary_g_nd<
-                          Op,
-                          InType,
-                          OutType,
-                          IdxT,
-                          dims_constant()>;
-                      auto [num_blocks, block_dims] =
-                          get_launch_args(kernel, out, large());
-                      kernel<<<num_blocks, block_dims, 0, stream>>>(
-                          a.data<InType>(),
-                          b.data<InType>(),
-                          out.data<OutType>(),
-                          out.size(),
-                          const_param<dims_constant()>(shape),
-                          const_param<dims_constant()>(a_strides),
-                          const_param<dims_constant()>(b_strides));
-                    });
-                  } else {
-                    auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-                    auto [num_blocks, block_dims] =
-                        get_launch_args(kernel, out, large());
-                    kernel<<<num_blocks, block_dims, 0, stream>>>(
-                        a.data<InType>(),
-                        b.data<InType>(),
-                        out.data<OutType>(),
-                        out.size(),
-                        const_param(shape),
-                        const_param(a_strides),
-                        const_param(b_strides),
-                        ndim);
-                  }
-                });
-          } else {
-            dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-              using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
-              if (bopt == BinaryOpType::ScalarVector) {
-                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorScalar) {
-                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorVector) {
-                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
-              }
-              auto [num_blocks, block_dims] = get_launch_args(
-                  kernel, out.data_size(), out.shape(), out.strides(), large());
-              kernel<<<num_blocks, block_dims, 0, stream>>>(
-                  a.data<InType>(),
-                  b.data<InType>(),
-                  out.data<OutType>(),
-                  out.data_size());
-            });
-          }
-        } else {
-          throw std::runtime_error(fmt::format(
-              "Can not do binary op {} on inputs of {} with result of {}.",
-              op,
-              dtype_to_string(a.dtype()),
-              dtype_to_string(out.dtype())));
-        }
-      });
-    });
-  });
+  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
+      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+        using InType = cuda_type_t<CTYPE_IN>;
+        using OutType = cuda_type_t<CTYPE_OUT>;
+        auto bopt = get_binary_op_type(a, b);
+        if (bopt == BinaryOpType::General) {
+          dispatch_bool(
+              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                  out.data_size() > INT32_MAX,
+              [&](auto large) {
+                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                Shape shape;
+                std::vector<Strides> strides;
+                std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
+                auto& a_strides = strides[0];
+                auto& b_strides = strides[1];
+                int ndim = shape.size();
+                if (ndim <= 3) {
+                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                    auto kernel = cu::
+                        binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
+                    auto [num_blocks, block_dims] =
+                        get_launch_args(kernel, out, large());
+                    encoder.add_kernel_node(
+                        kernel,
+                        num_blocks,
+                        block_dims,
+                        a.data<InType>(),
+                        b.data<InType>(),
+                        out.data<OutType>(),
+                        out.size(),
+                        const_param<dims_constant()>(shape),
+                        const_param<dims_constant()>(a_strides),
+                        const_param<dims_constant()>(b_strides));
+                  });
+                } else {
+                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out, large());
+                  encoder.add_kernel_node(
+                      kernel,
+                      num_blocks,
+                      block_dims,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out.data<OutType>(),
+                      out.size(),
+                      const_param(shape),
+                      const_param(a_strides),
+                      const_param(b_strides),
+                      ndim);
+                }
+              });
+        } else {
+          dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+            // TODO: Choose optimized value based on type size.
+            constexpr int N_READS = 4;
+            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
+            if (bopt == BinaryOpType::ScalarVector) {
+              kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
+            } else if (bopt == BinaryOpType::VectorScalar) {
+              kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
+            } else if (bopt == BinaryOpType::VectorVector) {
+              kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
+            }
+            auto [num_blocks, block_dims] = get_launch_args(
+                kernel,
+                out.data_size(),
+                out.shape(),
+                out.strides(),
+                large(),
+                N_READS);
+            encoder.add_kernel_node(
+                kernel,
+                num_blocks,
+                block_dims,
+                a.data<InType>(),
+                b.data<InType>(),
+                out.data<OutType>(),
+                out.data_size());
+          });
+        }
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do binary op {} on inputs of {} with result of {}.",
+            op,
+            dtype_to_string(a.dtype()),
+            dtype_to_string(out.dtype())));
+      }
+    });
+  });
 }
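A standalone sketch of the vectorized access pattern introduced above (the AlignedVector/load_vector/store_vector helpers appear in the kernel_utils.cuh diff further down). The helpers are re-declared locally under hypothetical names so the example compiles on its own; each thread processes N contiguous elements with one aligned load/store and falls back to scalar accesses for the tail, which is why the launch is sized over ceil(size / N) threads.

// Hypothetical sketch, not MLX code.
#include <cuda_runtime.h>

template <typename T, int N>
struct alignas(sizeof(T) * N) Vec {
  T val[N];
};

// Each thread handles N contiguous elements with one vectorized load/store;
// a tail shorter than N falls back to scalar accesses.
template <typename T, int N>
__global__ void scale(const T* in, T* out, T factor, int size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int remaining = size - index * N;
  if (remaining <= 0) {
    return;
  }
  if (remaining < N) {
    for (int i = 0; i < remaining; ++i) {
      out[index * N + i] = in[index * N + i] * factor;
    }
  } else {
    auto in_vec = reinterpret_cast<const Vec<T, N>*>(in)[index];
    Vec<T, N> out_vec;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      out_vec.val[i] = in_vec.val[i] * factor;
    }
    reinterpret_cast<Vec<T, N>*>(out)[index] = out_vec;
  }
}

int main() {
  const int size = 1003; // deliberately not a multiple of 4 to exercise the tail
  float *in, *out;
  cudaMalloc(&in, size * sizeof(float));
  cudaMalloc(&out, size * sizeof(float));
  constexpr int N = 4;
  int threads = 256;
  int items = (size + N - 1) / N; // one thread per group of N elements
  scale<float, N><<<(items + threads - 1) / threads, threads>>>(in, out, 2.0f, size);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}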

View File

@@ -137,98 +137,101 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(b);
   encoder.set_output_array(out_a);
   encoder.set_output_array(out_b);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
-        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
-        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
-        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
-          using InType = cuda_type_t<CTYPE_IN>;
-          using OutType = cuda_type_t<CTYPE_OUT>;
-          auto bopt = get_binary_op_type(a, b);
-          if (bopt == BinaryOpType::General) {
-            dispatch_bool(
-                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-                    out_a.data_size() > INT32_MAX,
-                [&](auto large) {
-                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-                  Shape shape;
-                  std::vector<Strides> strides;
-                  std::tie(shape, strides) =
-                      collapse_contiguous_dims(a, b, out_a);
-                  auto& a_strides = strides[0];
-                  auto& b_strides = strides[1];
-                  int ndim = shape.size();
-                  if (ndim <= 3) {
-                    dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                      auto kernel = cu::binary_g_nd<
-                          Op,
-                          InType,
-                          OutType,
-                          IdxT,
-                          dims_constant()>;
-                      auto [num_blocks, block_dims] =
-                          get_launch_args(kernel, out_a, large());
-                      kernel<<<num_blocks, block_dims, 0, stream>>>(
-                          a.data<InType>(),
-                          b.data<InType>(),
-                          out_a.data<OutType>(),
-                          out_b.data<OutType>(),
-                          out_a.size(),
-                          const_param<dims_constant()>(shape),
-                          const_param<dims_constant()>(a_strides),
-                          const_param<dims_constant()>(b_strides));
-                    });
-                  } else {
-                    auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-                    auto [num_blocks, block_dims] =
-                        get_launch_args(kernel, out_a, large());
-                    kernel<<<num_blocks, block_dims, 0, stream>>>(
-                        a.data<InType>(),
-                        b.data<InType>(),
-                        out_a.data<OutType>(),
-                        out_b.data<OutType>(),
-                        out_a.size(),
-                        const_param(shape),
-                        const_param(a_strides),
-                        const_param(b_strides),
-                        ndim);
-                  }
-                });
-          } else {
-            dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
-              using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
-              if (bopt == BinaryOpType::ScalarVector) {
-                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorScalar) {
-                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorVector) {
-                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
-              }
-              auto [num_blocks, block_dims] = get_launch_args(
-                  kernel,
-                  out_a.data_size(),
-                  out_a.shape(),
-                  out_a.strides(),
-                  large());
-              kernel<<<num_blocks, block_dims, 0, stream>>>(
-                  a.data<InType>(),
-                  b.data<InType>(),
-                  out_a.data<OutType>(),
-                  out_b.data<OutType>(),
-                  out_a.data_size());
-            });
-          }
-        } else {
-          throw std::runtime_error(fmt::format(
-              "Can not do binary op {} on inputs of {} with result of {}.",
-              op,
-              dtype_to_string(a.dtype()),
-              dtype_to_string(out_a.dtype())));
-        }
-      });
-    });
-  });
+  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
+      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
+      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+        using InType = cuda_type_t<CTYPE_IN>;
+        using OutType = cuda_type_t<CTYPE_OUT>;
+        auto bopt = get_binary_op_type(a, b);
+        if (bopt == BinaryOpType::General) {
+          dispatch_bool(
+              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                  out_a.data_size() > INT32_MAX,
+              [&](auto large) {
+                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                Shape shape;
+                std::vector<Strides> strides;
+                std::tie(shape, strides) =
+                    collapse_contiguous_dims(a, b, out_a);
+                auto& a_strides = strides[0];
+                auto& b_strides = strides[1];
+                int ndim = shape.size();
+                if (ndim <= 3) {
+                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                    auto kernel = cu::
+                        binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
+                    auto [num_blocks, block_dims] =
+                        get_launch_args(kernel, out_a, large());
+                    encoder.add_kernel_node(
+                        kernel,
+                        num_blocks,
+                        block_dims,
+                        a.data<InType>(),
+                        b.data<InType>(),
+                        out_a.data<OutType>(),
+                        out_b.data<OutType>(),
+                        out_a.size(),
+                        const_param<dims_constant()>(shape),
+                        const_param<dims_constant()>(a_strides),
+                        const_param<dims_constant()>(b_strides));
+                  });
+                } else {
+                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out_a, large());
+                  encoder.add_kernel_node(
+                      kernel,
+                      num_blocks,
+                      block_dims,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out_a.data<OutType>(),
+                      out_b.data<OutType>(),
+                      out_a.size(),
+                      const_param(shape),
+                      const_param(a_strides),
+                      const_param(b_strides),
+                      ndim);
+                }
+              });
+        } else {
+          dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
+            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+            if (bopt == BinaryOpType::ScalarVector) {
+              kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorScalar) {
+              kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorVector) {
+              kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+            }
+            auto [num_blocks, block_dims] = get_launch_args(
+                kernel,
+                out_a.data_size(),
+                out_a.shape(),
+                out_a.strides(),
+                large());
+            encoder.add_kernel_node(
+                kernel,
+                num_blocks,
+                block_dims,
+                a.data<InType>(),
+                b.data<InType>(),
+                out_a.data<OutType>(),
+                out_b.data<OutType>(),
+                out_a.data_size());
+          });
+        }
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do binary op {} on inputs of {} with result of {}.",
+            op,
+            dtype_to_string(a.dtype()),
+            dtype_to_string(out_a.dtype())));
+      }
+    });
+  });
 }

View File

@@ -3,6 +3,7 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/graph_utils.h"
 #include "mlx/primitives.h"

@@ -178,6 +179,7 @@ void Compiled::eval_gpu(
   // Whether to use large index.
   bool large = compiled_use_large_index(inputs, outputs, contiguous);

+  cu::KernelArgs args;
   // Put inputs.
   int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
@@ -185,26 +187,26 @@ void Compiled::eval_gpu(
       continue;
     }
     const auto& x = inputs[i];
-    mod.append_arg(x);
+    args.append(x);
     if (!contiguous && !is_scalar(x)) {
-      mod.append_arg(strides_vec[strides_index++]);
+      args.append_ptr(strides_vec[strides_index++].data());
     }
   }

   // Put outputs.
   compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
   for (auto& x : outputs) {
-    mod.append_arg(x);
+    args.append(x);
   }

   // Put shape and size.
   if (!contiguous) {
-    mod.append_arg(shape);
+    args.append_ptr(shape.data());
   }
   if (large) {
-    mod.append_arg<int64_t>(outputs[0].data_size());
+    args.append<int64_t>(outputs[0].data_size());
   } else {
-    mod.append_arg<uint32_t>(outputs[0].data_size());
+    args.append<uint32_t>(outputs[0].data_size());
   }

   // Launch kernel.
@@ -222,9 +224,10 @@ void Compiled::eval_gpu(
   for (const auto& out : outputs) {
     encoder.set_output_array(out);
   }
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    mod.launch_kernel(stream, kernel_name, outputs[0], large);
-  });
+
+  auto kernel = mod.get_kernel(kernel_name);
+  auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }

 } // namespace mlx::core
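For reference, a hypothetical standalone sketch of the argument-packing convention used above: cudaGraphAddKernelNode (like cuLaunchKernel) takes a void** array where each entry points at the storage of one kernel parameter, and that storage must remain valid until the node is created.

// Hypothetical sketch, not MLX code.
#include <cuda_runtime.h>

__global__ void axpy(float a, const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i] + y[i];
  }
}

int main() {
  const int n = 1 << 16;
  float *x, *y;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&y, n * sizeof(float));

  // One pointer per kernel parameter, each pointing at that argument's storage.
  float a = 2.0f;
  int size = n;
  void* args[] = {&a, &x, &y, &size};

  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  cudaKernelNodeParams params = {0};
  params.func = (void*)axpy;
  params.gridDim = dim3((n + 255) / 256);
  params.blockDim = dim3(256);
  params.kernelParams = args;

  cudaGraphNode_t node;
  cudaGraphAddKernelNode(&node, graph, NULL, 0, &params);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, NULL, NULL, 0);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  cudaFree(y);
  return 0;
}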

View File

@@ -35,24 +35,25 @@ void copy_contiguous(
     array& out,
     int64_t in_offset,
     int64_t out_offset) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-          using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-          using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-          auto kernel = cu::copy_s<InType, OutType, IdxT>;
-          if (ctype == CopyType::Vector) {
-            kernel = cu::copy_v<InType, OutType, IdxT>;
-          }
-          auto [num_blocks, block_dims] = get_launch_args(
-              kernel, out.data_size(), out.shape(), out.strides(), large());
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>() + in_offset,
-              out.data<OutType>() + out_offset,
-              out.data_size());
-        });
-      });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
+        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+        auto kernel = cu::copy_s<InType, OutType, IdxT>;
+        if (ctype == CopyType::Vector) {
+          kernel = cu::copy_v<InType, OutType, IdxT>;
+        }
+        auto [num_blocks, block_dims] = get_launch_args(
+            kernel, out.data_size(), out.shape(), out.strides(), large());
+        encoder.add_kernel_node(
+            kernel,
+            num_blocks,
+            block_dims,
+            in.data<InType>() + in_offset,
+            out.data<OutType>() + out_offset,
+            out.data_size());
+      });
     });
   });

View File

@@ -55,50 +55,54 @@ void copy_general(
     const Shape& shape,
     const Strides& strides_in,
     const Strides& strides_out) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              size_t data_size = 1;
-              for (auto& s : shape)
-                data_size *= s;
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-                  auto kernel =
-                      cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
-                  auto [num_blocks, block_dims] = get_launch_args(
-                      kernel, data_size, shape, out.strides(), large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      data_size,
-                      const_param<ndim_constant()>(shape),
-                      const_param<ndim_constant()>(strides_in),
-                      const_param<ndim_constant()>(strides_out));
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] = get_launch_args(
-                    kernel, data_size, shape, out.strides(), large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    data_size,
-                    const_param(shape),
-                    const_param(strides_in),
-                    const_param(strides_out),
-                    ndim);
-              }
-            });
-      });
-    });
-  });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            size_t data_size = 1;
+            for (auto& s : shape)
+              data_size *= s;
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+                auto kernel =
+                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
+                auto [num_blocks, block_dims] = get_launch_args(
+                    kernel, data_size, shape, out.strides(), large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    data_size,
+                    const_param<ndim_constant()>(shape),
+                    const_param<ndim_constant()>(strides_in),
+                    const_param<ndim_constant()>(strides_out));
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_gg<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] = get_launch_args(
+                  kernel, data_size, shape, out.strides(), large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  data_size,
+                  const_param(shape),
+                  const_param(strides_in),
+                  const_param(strides_out),
+                  ndim);
+            }
+          });
+    });
+  });
 }

View File

@@ -61,54 +61,55 @@ void copy_general_dynamic(
     const Strides& strides_out,
     const array& dynamic_offset_in,
     const array& dynamic_offset_out) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                  auto kernel = cu::copy_gg_dynamic_nd<
-                      InType,
-                      OutType,
-                      IdxT,
-                      dims_constant()>;
-                  auto [num_blocks, block_dims] =
-                      get_launch_args(kernel, out, large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      out.size(),
-                      const_param<dims_constant()>(shape),
-                      const_param<dims_constant()>(strides_in),
-                      const_param<dims_constant()>(strides_out),
-                      dynamic_offset_in.data<int64_t>(),
-                      dynamic_offset_out.data<int64_t>());
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out, large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    out.size(),
-                    const_param(shape),
-                    const_param(strides_in),
-                    const_param(strides_out),
-                    ndim,
-                    dynamic_offset_in.data<int64_t>(),
-                    dynamic_offset_out.data<int64_t>());
-              }
-            });
-      });
-    });
-  });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel = cu::
+                    copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    out.size(),
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(strides_in),
+                    const_param<dims_constant()>(strides_out),
+                    dynamic_offset_in.data<int64_t>(),
+                    dynamic_offset_out.data<int64_t>());
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] =
+                  get_launch_args(kernel, out, large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  out.size(),
+                  const_param(shape),
+                  const_param(strides_in),
+                  const_param(strides_out),
+                  ndim,
+                  dynamic_offset_in.data<int64_t>(),
+                  dynamic_offset_out.data<int64_t>());
+            }
+          });
+    });
+  });
 }

View File

@@ -50,45 +50,49 @@ void copy_general_input(
     int64_t offset_out,
     const Shape& shape,
     const Strides& strides_in) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                  auto kernel =
-                      cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
-                  auto [num_blocks, block_dims] =
-                      get_launch_args(kernel, out, large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      out.size(),
-                      const_param<dims_constant()>(shape),
-                      const_param<dims_constant()>(strides_in));
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_g<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out, large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    out.size(),
-                    const_param(shape),
-                    const_param(strides_in),
-                    ndim);
-              }
-            });
-      });
-    });
-  });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel =
+                    cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    out.size(),
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(strides_in));
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_g<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] =
+                  get_launch_args(kernel, out, large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  out.size(),
+                  const_param(shape),
+                  const_param(strides_in),
+                  ndim);
+            }
+          });
+    });
+  });
 }

View File

@@ -2,38 +2,28 @@

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/metal/metal.h"
+#include "mlx/utils.h"

 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>

 #include <future>
+#include <unordered_set>

 namespace mlx::core {

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+// This should be less than 255
+constexpr int default_max_nodes_per_graph = 20;
+
+int cuda_graph_cache_size() {
+  static int cache_size = []() {
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+  }();
+  return cache_size;
+}
+
 namespace cu {

-DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
-
-void DeviceStream::synchronize() {
-  cudaStreamSynchronize(stream_);
-}
-
-cudaStream_t DeviceStream::schedule_cuda_stream() {
-  // TODO: Return a stream that maximizes parallelism.
-  return stream_;
-}
-
-cudaStream_t DeviceStream::last_cuda_stream() {
-  return stream_;
-}
-
-CommandEncoder& DeviceStream::get_encoder() {
-  if (!encoder_) {
-    encoder_ = std::make_unique<CommandEncoder>(*this);
-  }
-  return *encoder_;
-}
-
 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
       &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
@@ -67,49 +57,261 @@ void Device::make_current() {
   }
 }

-DeviceStream& Device::get_stream(Stream s) {
-  auto it = streams_.find(s.index);
-  if (it == streams_.end()) {
-    it = streams_.try_emplace(s.index, *this).first;
+CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
+  CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
+  CHECK_CUDA_ERROR(
+      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
+}
+
+CommandEncoder::CaptureContext::~CaptureContext() {
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  size_t num_nodes;
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
+  if (num_nodes == 1) {
+    cudaGraphNode_t captured_node;
+    CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
+    CUDA_KERNEL_NODE_PARAMS params;
+    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+    cudaGraphNode_t node;
+    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
+    enc.insert_graph_dependencies(GraphNode{node, 'K'});
+  } else {
+    cudaGraphNode_t node;
+    CHECK_CUDA_ERROR(
+        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
+    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+  }
+  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+}
+
+CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
+    : enc(enc) {
+  enc.in_concurrent_ = true;
+}
+
+CommandEncoder::ConcurrentContext::~ConcurrentContext() {
+  enc.in_concurrent_ = false;
+
+  // Use an empty graph node for synchronization
+  CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
+  enc.empty_node_count_++;
+  CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
+
+  // Insert the concurrent -> empty node dependencies
+  for (auto& from : enc.concurrent_nodes_) {
+    enc.from_nodes_.push_back(from.node);
+    enc.to_nodes_.push_back(empty.node);
+    enc.graph_key_ += from.id;
+    enc.graph_key_ += from.node_type;
+    enc.graph_key_ += empty.id;
+    enc.graph_key_ += empty.node_type;
+  }
+
+  // Insert the input -> concurrent node dependencies without updating output
+  // nodes
+  auto outputs = std::move(enc.active_outputs_);
+  enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));
+
+  // Update output node to be the empty node
+  for (auto o : outputs) {
+    enc.node_map_.emplace(o, empty).first->second = empty;
+  }
+}
+
+void CommandEncoder::insert_graph_dependencies(GraphNode node) {
+  if (node.node_type == 'G') {
+    graph_node_count_++;
+  }
+  node.id = std::to_string(node_count_++);
+  if (in_concurrent_) {
+    concurrent_nodes_.push_back(std::move(node));
+  } else {
+    std::vector<GraphNode> nodes;
+    nodes.push_back(std::move(node));
+    insert_graph_dependencies(std::move(nodes));
+  }
+}
+
+void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
+  std::vector<GraphNode> deps;
+  {
+    // Dependencies must be added in the same order to produce a consistent
+    // topology
+    std::unordered_set<cudaGraphNode_t> set_deps;
+    for (auto d : active_deps_) {
+      if (auto it = node_map_.find(d); it != node_map_.end()) {
+        auto [_, inserted] = set_deps.insert(it->second.node);
+        if (inserted) {
+          deps.push_back(it->second);
+        }
+      }
+    }
+  }
+  active_deps_.clear();
+
+  for (auto o : active_outputs_) {
+    for (auto& node : nodes) {
+      node_map_.emplace(o, node).first->second = node;
+    }
+  }
+  active_outputs_.clear();
+
+  for (auto& from : deps) {
+    for (auto& to : nodes) {
+      from_nodes_.push_back(from.node);
+      to_nodes_.push_back(to.node);
+      graph_key_ += from.id;
+      graph_key_ += from.node_type;
+      graph_key_ += to.id;
+      graph_key_ += to.node_type;
+    }
+  }
+}
+
+CommandEncoder& Device::get_command_encoder(Stream s) {
+  auto it = encoders_.find(s.index);
+  if (it == encoders_.end()) {
+    it = encoders_.try_emplace(s.index, *this).first;
   }
   return it->second;
 }

-CommandEncoder::CommandEncoder(DeviceStream& s)
-    : device_(s.device()), stream_(s) {}
+CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
+  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+}
+
+void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
+  for (auto& [_, graph_exec] : graphs) {
+    CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
+  }
+  graphs.clear();
+}
+
+CommandEncoder::~CommandEncoder() {
+  clear_graphs(graph_cache_);
+}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
 }

-void CommandEncoder::end_encoding() {
-  if (!temporaries_.empty()) {
-    add_completed_handler([temporaries = std::move(temporaries_)]() {});
-  }
+void CommandEncoder::set_input_array(const array& arr) {
+  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
+  active_deps_.push_back(id);
+}

-  // There is no kernel running, run completion handlers immediately.
-  if (!has_gpu_work_) {
-    worker_.consume_in_this_thread();
-    return;
-  }
-  has_gpu_work_ = false;
+void CommandEncoder::set_output_array(const array& arr) {
+  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
+  active_deps_.push_back(id);
+  active_outputs_.push_back(id);
+}

-  // Put completion handlers in a batch.
-  worker_.end_batch();
-  // Signaling kernel completion is expensive, delay until enough batches.
-  // TODO: This number is arbitrarily picked, profile for a better stragety.
-  if (worker_.uncommited_batches() > 8) {
+void CommandEncoder::maybe_commit() {
+  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
     commit();
   }
 }

+void CommandEncoder::add_kernel_node(
+    void* func,
+    dim3 grid_dim,
+    dim3 block_dim,
+    void** params) {
+  cudaKernelNodeParams kernel_params = {0};
+  kernel_params.func = func;
+  kernel_params.gridDim = grid_dim;
+  kernel_params.blockDim = block_dim;
+  kernel_params.kernelParams = params;
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(
+      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_kernel_node(
+    CUfunction func,
+    dim3 grid_dim,
+    dim3 block_dim,
+    void** params) {
+  CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
+  kernel_params.func = func;
+  kernel_params.gridDimX = grid_dim.x;
+  kernel_params.gridDimY = grid_dim.y;
+  kernel_params.gridDimZ = grid_dim.z;
+  kernel_params.blockDimX = block_dim.x;
+  kernel_params.blockDimY = block_dim.y;
+  kernel_params.blockDimZ = block_dim.z;
+  kernel_params.kernelParams = params;
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(
+      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
 void CommandEncoder::commit() {
-  worker_.commit(stream_.last_cuda_stream());
+  if (!temporaries_.empty()) {
+    add_completed_handler([temporaries = std::move(temporaries_)]() {});
+  }
+  if (node_count_ > 0) {
+    if (!from_nodes_.empty()) {
+      CHECK_CUDA_ERROR(cudaGraphAddDependencies(
+          graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
+    }
+    graph_key_ += ".";
+    graph_key_ += std::to_string(node_count_);
+    graph_key_ += ".";
+    graph_key_ += std::to_string(graph_node_count_);
+    graph_key_ += ".";
+    graph_key_ += std::to_string(empty_node_count_);
+    cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+    if (graph_exec != nullptr) {
+      cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+      cudaGraphExecUpdateResultInfo info;
+      cudaGraphExecUpdate(graph_exec, graph_, &info);
+      update_result = info.result;
+#else
+      cudaGraphNode_t error_node;
+      cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+      if (update_result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError(); // reset error
+        CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
+        graph_exec = nullptr;
+      }
+    }
+    if (graph_exec == nullptr) {
+      CHECK_CUDA_ERROR(
+          cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
+    }
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+
+    // TODO smarter cache policy
+    if (graph_cache_.size() > cuda_graph_cache_size()) {
+      clear_graphs(graph_cache_);
+    }
+
+    // Reset state
+    node_count_ = 0;
+    graph_node_count_ = 0;
+    from_nodes_.clear();
+    to_nodes_.clear();
+    graph_key_.clear();
+    node_map_.clear();
+    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
+    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  }
+
+  // Put completion handlers in a batch.
+  worker_.end_batch();
+  worker_.commit(stream_);
 }

 void CommandEncoder::synchronize() {
-  stream().synchronize();
+  cudaStreamSynchronize(stream_);
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
@@ -127,12 +329,8 @@ Device& device(mlx::core::Device device) {
   return it->second;
 }

-DeviceStream& get_stream(Stream s) {
-  return device(s.device).get_stream(s);
-}
-
 CommandEncoder& get_command_encoder(Stream s) {
-  return get_stream(s).get_encoder();
+  return device(s.device).get_command_encoder(s);
 }

 } // namespace cu
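A hypothetical standalone sketch of the dependency bookkeeping that CommandEncoder performs above: kernel nodes without edges between them may execute concurrently, and an empty node (as in ConcurrentContext) serves as the join point that later work depends on.

// Hypothetical sketch, not MLX code.
#include <cuda_runtime.h>

__global__ void fill(float* p, float v, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    p[i] = v;
  }
}

int main() {
  const int n = 1 << 16;
  float *a, *b;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));

  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  float va = 1.0f, vb = 2.0f;
  int size = n;
  void* args_a[] = {&a, &va, &size};
  void* args_b[] = {&b, &vb, &size};

  cudaKernelNodeParams params = {0};
  params.func = (void*)fill;
  params.gridDim = dim3((n + 255) / 256);
  params.blockDim = dim3(256);

  // Two independent kernel nodes: no dependency edge, so they may overlap.
  cudaGraphNode_t ka, kb;
  params.kernelParams = args_a;
  cudaGraphAddKernelNode(&ka, graph, NULL, 0, &params);
  params.kernelParams = args_b;
  cudaGraphAddKernelNode(&kb, graph, NULL, 0, &params);

  // Empty node used purely for synchronization: it runs after both kernels.
  cudaGraphNode_t join;
  cudaGraphAddEmptyNode(&join, graph, NULL, 0);
  cudaGraphNode_t from[] = {ka, kb};
  cudaGraphNode_t to[] = {join, join};
  cudaGraphAddDependencies(graph, from, to, 2);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, NULL, NULL, 0);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(a);
  cudaFree(b);
  return 0;
}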

View File

@@ -7,41 +7,108 @@
 #include "mlx/stream.h"

 #include <cublasLt.h>
+#include <cuda.h>
 #include <thrust/execution_policy.h>

 #include <unordered_map>

 namespace mlx::core::cu {

-class Device;
-class CommandEncoder;
-
-class DeviceStream {
+class CommandEncoder {
  public:
-  explicit DeviceStream(Device& device);
+  struct CaptureContext {
+    CaptureContext(CommandEncoder& enc);
+    ~CaptureContext();
+    cudaGraph_t graph;
+    CommandEncoder& enc;
+  };
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc);
+    ~ConcurrentContext();
+    CommandEncoder& enc;
+  };

-  DeviceStream(const DeviceStream&) = delete;
-  DeviceStream& operator=(const DeviceStream&) = delete;
+  explicit CommandEncoder(Device& d);
+  ~CommandEncoder();

-  // Wait until kernels in the stream complete.
-  void synchronize();
+  CommandEncoder(const CommandEncoder&) = delete;
+  CommandEncoder& operator=(const CommandEncoder&) = delete;

-  // Return a cuda stream for launching kernels.
-  cudaStream_t schedule_cuda_stream();
+  CaptureContext capture_context() {
+    return CaptureContext{*this};
+  }
+  ConcurrentContext concurrent_context() {
+    return ConcurrentContext{*this};
+  }

-  // Return the last cuda stream used.
-  cudaStream_t last_cuda_stream();
-
-  CommandEncoder& get_encoder();
-
-  Device& device() {
-    return device_;
-  }
+  void set_input_array(const array& arr);
+  void set_output_array(const array& arr);
+
+  template <typename F, typename... Params>
+  void
+  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
+    constexpr size_t num = sizeof...(Params);
+    void* ptrs[num];
+    size_t i = 0;
+    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
+         std::forward<Params>(params)),
+     ...);
+    add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
+  }
+
+  void add_kernel_node(
+      CUfunction func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      void** params);
+
+  void
+  add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
+
+  void add_temporary(const array& arr) {
+    temporaries_.push_back(arr.data_shared_ptr());
+  }
+
+  void add_completed_handler(std::function<void()> task);
+  void maybe_commit();
+  void commit();
+
+  CudaStream& stream() {
+    return stream_;
+  }
+
+  // Wait until kernels and completion handlers are finished
+  void synchronize();

  private:
-  Device& device_;
+  struct GraphNode {
+    cudaGraphNode_t node;
+    // K = kernel
+    // E = empty
+    // G = subgraph
+    char node_type;
+    std::string id;
+  };
+
+  void insert_graph_dependencies(GraphNode node);
+  void insert_graph_dependencies(std::vector<GraphNode> nodes);
+
   CudaStream stream_;
-  std::unique_ptr<CommandEncoder> encoder_;
+  cudaGraph_t graph_;
+  Worker worker_;
+  char node_count_{0};
+  char graph_node_count_{0};
+  char empty_node_count_{0};
+  bool in_concurrent_{false};
+  std::vector<cudaGraphNode_t> from_nodes_;
+  std::vector<cudaGraphNode_t> to_nodes_;
+  std::string graph_key_;
+  std::vector<GraphNode> concurrent_nodes_;
+  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
+  std::vector<std::uintptr_t> active_deps_;
+  std::vector<std::uintptr_t> active_outputs_;
+  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
 };

 class Device {
@@ -55,7 +122,7 @@ class Device {
   // Make this device the current cuda device, required by some cuda calls.
   void make_current();

-  DeviceStream& get_stream(Stream s);
+  CommandEncoder& get_command_encoder(Stream s);

   int cuda_device() const {
     return device_;
@@ -75,67 +142,10 @@ class Device {
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
-  std::unordered_map<int, DeviceStream> streams_;
-};
-
-class CommandEncoder {
- public:
-  explicit CommandEncoder(DeviceStream& stream);
-  CommandEncoder(const CommandEncoder&) = delete;
-  CommandEncoder& operator=(const CommandEncoder&) = delete;
-
-  void set_input_array(const array& arr) {}
-  void set_output_array(const array& arr) {}
-
-  void add_temporary(const array& arr) {
-    temporaries_.push_back(arr.data_shared_ptr());
-  }
-
-  void add_completed_handler(std::function<void()> task);
-  void end_encoding();
-  void commit();
-
-  // Schedule a cuda stream for |fun| to launch kernels, and check error
-  // afterwards.
-  template <typename F>
-  void launch_kernel(F&& fun) {
-    launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
-  }
-
-  template <typename F>
-  void launch_kernel(cudaStream_t stream, F&& fun) {
-    device_.make_current();
-    fun(stream);
-    check_cuda_error("kernel launch", cudaGetLastError());
-    has_gpu_work_ = true;
-  }
-
-  Device& device() {
-    return device_;
-  }
-
-  DeviceStream& stream() {
-    return stream_;
-  }
-
-  bool has_gpu_work() const {
-    return has_gpu_work_;
-  }
-
-  // Wait until kernels and completion handlers are finished
-  void synchronize();
-
- private:
-  Device& device_;
-  DeviceStream& stream_;
-  Worker worker_;
-  bool has_gpu_work_{false};
-  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  std::unordered_map<int, CommandEncoder> encoders_;
 };

 Device& device(mlx::core::Device device);
-DeviceStream& get_stream(Stream s);
 CommandEncoder& get_command_encoder(Stream s);

 // Return an execution policy that does not sync for result.
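The templated add_kernel_node above converts a parameter pack into the void** array form with a fold expression. Below is a host-only sketch of the same trick with illustrative names (not the MLX API); the addresses are taken of the forwarded arguments themselves, so they must stay alive at least until the kernel node is created from the array.

// Host-only, hypothetical sketch of the fold-expression trick.
#include <cstddef>
#include <cstdio>
#include <utility>

template <typename... Params>
void collect_and_print(Params&&... params) {
  constexpr size_t num = sizeof...(Params);
  void* ptrs[num];
  size_t i = 0;
  // The lambda runs once per argument, left to right, storing its address.
  ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
       std::forward<Params>(params)),
   ...);
  // A real encoder would hand `ptrs` to cudaGraphAddKernelNode here.
  for (size_t j = 0; j < num; ++j) {
    printf("arg %zu at %p\n", j, ptrs[j]);
  }
}

int main() {
  int n = 42;
  float scale = 0.5f;
  double* data = nullptr;
  collect_and_print(n, scale, data);
  return 0;
}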

View File

@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <cuComplex.h> #include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h> #include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
} }
}; };
// Castings between complex and boolean.
// TODO: Should make a custom complex type.
template <>
struct CastOp<cuComplex, bool> {
static constexpr bool is_castable = true;
__device__ bool operator()(cuComplex x) {
return x.x != 0 && x.y != 0;
}
};
template <>
struct CastOp<bool, cuComplex> {
static constexpr bool is_castable = true;
__device__ cuComplex operator()(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
};
// Converting a complex number to real number discards the imaginary part. // Converting a complex number to real number discards the imaginary part.
template <typename DstT> template <typename DstT>
struct CastOp< struct CastOp<
@@ -45,6 +67,7 @@ struct CastOp<
} }
}; };
// Do nothing when no casting is needed.
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
struct CastOp< struct CastOp<
SrcT, SrcT,
@@ -57,9 +80,53 @@ struct CastOp<
} }
}; };
// In CUDA 11 the half types do not define conversions between some types,
// provide fallbacks here.
#if CUDART_VERSION < 12000
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> &&
!cuda::std::is_same_v<SrcT, cuComplex> &&
(cuda::std::is_same_v<DstT, __half> ||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> &&
!cuda::std::is_same_v<DstT, cuComplex> &&
!cuda::std::is_same_v<DstT, __half> &&
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
(cuda::std::is_same_v<SrcT, __half> ||
cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
#endif // CUDART_VERSION < 12000
// Helper to deduce the SrcT.
template <typename DstT, typename SrcT>
inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp. // Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator> template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) { inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type; using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) { if constexpr (std::is_same_v<SrcT, DstT>) {
return it; return it;
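The new specializations cover two gaps: casts between cuComplex and bool, and conversions that the CUDA 11 half types do not define, which are routed through float. cast_to<DstT>(x) is the single entry point that the reduction kernels later in this diff switch to. A small device-side sketch, assuming this header is included:

// Sketch only: cast_to resolves through CastOp<SrcT, DstT>, so it compiles on
// CUDA 11 as well; any SrcT -> __nv_bfloat16 conversion the toolkit lacks goes
// through the float fallback above.
__global__ void cast_example(const long long* in, __nv_bfloat16* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = cast_to<__nv_bfloat16>(in[i]);
  }
}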

View File

@@ -28,6 +28,27 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>; using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>; using Strides = cuda::std::array<int64_t, MAX_NDIM>;
// Vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
};
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -78,20 +99,20 @@ struct Limits<
return cuda::std::numeric_limits<T>::infinity(); return cuda::std::numeric_limits<T>::infinity();
} }
static constexpr __host__ __device__ T min() { static constexpr __host__ __device__ T min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 #if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return -cuda::std::numeric_limits<T>::infinity();
#else
return -cuda::std::numeric_limits<float>::infinity(); return -cuda::std::numeric_limits<float>::infinity();
#else
return -cuda::std::numeric_limits<T>::infinity();
#endif #endif
} }
static constexpr __host__ __device__ T finite_max() { static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max(); return cuda::std::numeric_limits<T>::max();
} }
static constexpr __host__ __device__ T finite_min() { static constexpr __host__ __device__ T finite_min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 #if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return cuda::std::numeric_limits<T>::lowest();
#else
return cuda::std::numeric_limits<float>::lowest(); return cuda::std::numeric_limits<float>::lowest();
#else
return cuda::std::numeric_limits<T>::lowest();
#endif #endif
} }
}; };
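Two additions here: AlignedVector with load_vector/store_vector gives elementwise kernels a sized, aligned type so the compiler can emit wide vector loads and stores, and the Limits guards are flipped so that configurations where the 16-bit types lack the needed numeric_limits support (CUDA 11 device code for older architectures) fall back to the float limits. A minimal sketch of the vectorized helpers; the offset is counted in whole vectors, and the pointers are assumed aligned to sizeof(T) * N with size a multiple of N:

// Sketch only: one thread handles N contiguous elements with a single aligned
// vector load and store.
template <typename T, int N = 4>
__global__ void negate_vectorized(const T* in, T* out, uint32_t size) {
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;  // vector index, not element index
  if ((i + 1) * N <= size) {
    AlignedVector<T, N> v = load_vector<N>(in, i);
    for (int j = 0; j < N; ++j) {
      v.val[j] = -v.val[j];
    }
    store_vector<N>(out, i, v);
  }
}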

View File

@@ -37,22 +37,20 @@ void eval(array& arr) {
} }
auto& encoder = cu::get_command_encoder(arr.primitive().stream()); auto& encoder = cu::get_command_encoder(arr.primitive().stream());
if (encoder.has_gpu_work()) { // Keep used buffers alive until kernel finishes running.
// Keep used buffers alive until kernel finishes running. std::unordered_set<std::shared_ptr<array::Data>> buffers;
std::unordered_set<std::shared_ptr<array::Data>> buffers; for (auto& in : arr.inputs()) {
for (auto& in : arr.inputs()) { buffers.insert(in.data_shared_ptr());
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
} }
encoder.end_encoding(); for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
encoder.maybe_commit();
} }
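With graph recording there is no per-launch has_gpu_work check anymore, so eval always retains the input and sibling buffers and leaves the commit decision to maybe_commit. The mechanism is plain shared_ptr ownership moved into a no-op completion callback; a sketch of that pattern in isolation, assuming the encoder interface shown in this diff:

// Sketch only: keep buffers alive until everything recorded so far has run.
void retain_until_done(
    cu::CommandEncoder& encoder,
    std::unordered_set<std::shared_ptr<array::Data>> buffers) {
  // The lambda body is intentionally empty; it only owns the shared_ptrs,
  // which are released when the committed work finishes and the handler runs.
  encoder.add_completed_handler([buffers = std::move(buffers)]() {});
  encoder.maybe_commit();  // the encoder batches nodes and commits when appropriate
}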
void finalize(Stream s) { void finalize(Stream s) {

View File

@@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); }); scheduler::enqueue(s, [*this]() mutable { wait(); });
} else { } else {
wait(cu::get_stream(s).last_cuda_stream()); auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
} }
} }
@@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream."); throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else { } else {
record(cu::get_stream(s).last_cuda_stream()); auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
} }
} }
@@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else { } else {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.launch_kernel( encoder.commit();
encoder.stream().last_cuda_stream(), wait(encoder.stream(), value);
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
} }
} }
@@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) {
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
} else { } else {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.launch_kernel( encoder.commit();
encoder.stream().last_cuda_stream(), signal(encoder.stream(), value);
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
} }
} }
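Events used to act on the stream directly; with graph recording, nodes that were added but not yet committed would be invisible to the event, so CudaEvent and SharedEvent now force a commit first and then record or wait on the encoder's stream. A condensed sketch of the ordering, assuming enc.stream() is usable as a cudaStream_t as in the cublasLtMatmul call later in this diff:

// Sketch only: flush recorded graph nodes, then record the event behind them.
void record_after_pending_work(cudaEvent_t event, Stream s) {
  auto& enc = cu::get_command_encoder(s);
  enc.commit();  // turn the recorded nodes into real work on the stream
  CHECK_CUDA_ERROR(cudaEventRecord(event, enc.stream()));
}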

View File

@@ -3,13 +3,16 @@
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "cuda_jit_sources.h" #include "cuda_jit_sources.h"
#include <cuda.h>
#include <fmt/format.h> #include <fmt/format.h>
#include <nvrtc.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cassert> #include <cassert>
@@ -22,7 +25,7 @@ namespace {
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
void append_indices_arg( void append_indices_arg(
cu::JitModule& mod, cu::KernelArgs& args,
const std::vector<array>& inputs, const std::vector<array>& inputs,
int nidx, int nidx,
int idx_ndim) { int idx_ndim) {
@@ -30,7 +33,7 @@ void append_indices_arg(
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>(); indices[i] = inputs[i + 1].data<void>();
} }
mod.append_arg(std::move(indices)); args.append(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim); std::vector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy_n( std::copy_n(
@@ -38,7 +41,7 @@ void append_indices_arg(
idx_ndim, idx_ndim,
indices_shape.data() + i * idx_ndim); indices_shape.data() + i * idx_ndim);
} }
mod.append_arg(std::move(indices_shape)); args.append(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim); std::vector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy_n( std::copy_n(
@@ -46,7 +49,7 @@ void append_indices_arg(
idx_ndim, idx_ndim,
indices_strides.data() + i * idx_ndim); indices_strides.data() + i * idx_ndim);
} }
mod.append_arg(std::move(indices_strides)); args.append(std::move(indices_strides));
} }
} // namespace } // namespace
@@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_gather, std::move(kernel_names)); return std::make_pair(jit_source_gather, std::move(kernel_names));
}); });
mod.append_arg(src); cu::KernelArgs args;
mod.append_arg(out); args.append(src);
args.append(out);
if (large) { if (large) {
mod.append_arg<int64_t>(out.size()); args.append<int64_t>(out.size());
} else { } else {
mod.append_arg<int32_t>(out.size()); args.append<int32_t>(out.size());
} }
mod.append_ndim_arg(src.shape()); args.append_ndim(src.shape());
mod.append_ndim_arg(src.strides()); args.append_ndim(src.strides());
mod.append_arg<int32_t>(src.ndim()); args.append<int32_t>(src.ndim());
mod.append_ndim_arg(slice_sizes_); args.append_ndim(slice_sizes_);
mod.append_arg(slice_size); args.append(slice_size);
mod.append_arg(axes_); args.append(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"mlx::core::cu::gather<{}, {}, {}, {}, {}>", "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
@@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, out, large); auto kernel = mod.get_kernel(kernel_name);
}); auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_scatter, std::move(kernel_names)); return std::make_pair(jit_source_scatter, std::move(kernel_names));
}); });
mod.append_arg(upd); cu::KernelArgs args;
mod.append_arg(out); args.append(upd);
args.append(out);
if (large) { if (large) {
mod.append_arg<int64_t>(upd.size()); args.append<int64_t>(upd.size());
} else { } else {
mod.append_arg<int32_t>(upd.size()); args.append<int32_t>(upd.size());
} }
mod.append_ndim_arg(upd.shape()); args.append_ndim(upd.shape());
mod.append_ndim_arg(upd.strides()); args.append_ndim(upd.strides());
mod.append_arg<int32_t>(upd.ndim()); args.append<int32_t>(upd.ndim());
if (large) { if (large) {
mod.append_arg<int64_t>(upd_post_idx_size); args.append<int64_t>(upd_post_idx_size);
} else { } else {
mod.append_arg<int32_t>(upd_post_idx_size); args.append<int32_t>(upd_post_idx_size);
} }
mod.append_ndim_arg(out.shape()); args.append_ndim(out.shape());
mod.append_ndim_arg(out.strides()); args.append_ndim(out.strides());
mod.append_arg<int32_t>(out.ndim()); args.append<int32_t>(out.ndim());
mod.append_arg(axes_); args.append(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
@@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { auto kernel = mod.get_kernel(kernel_name);
mod.launch_kernel(stream, kernel_name, upd, large); auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
}); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
size_t idx_size_axis = idx.shape(axis_); size_t idx_size_axis = idx.shape(axis_);
mod.append_arg(src); cu::KernelArgs args;
mod.append_arg(idx); args.append(src);
mod.append_arg(out); args.append(idx);
args.append(out);
if (large) { if (large) {
mod.append_arg<int64_t>(idx_size_pre); args.append<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis); args.append<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post); args.append<int64_t>(idx_size_post);
} else { } else {
mod.append_arg<int32_t>(idx_size_pre); args.append<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis); args.append<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post); args.append<int32_t>(idx_size_post);
} }
mod.append_arg(remove_index(idx.shape(), axis_)); args.append(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_)); args.append(remove_index(src.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_)); args.append(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_); args.append<int32_t>(axis_);
mod.append_arg(src.shape(axis_)); args.append(src.shape(axis_));
mod.append_arg(src.strides(axis_)); args.append(src.strides(axis_));
mod.append_arg(idx.strides(axis_)); args.append(idx.strides(axis_));
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
@@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { auto kernel = mod.get_kernel(kernel_name);
mod.launch_kernel(stream, kernel_name, idx, large); auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
}); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) { void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
size_t idx_size_axis = idx.shape(axis_); size_t idx_size_axis = idx.shape(axis_);
mod.append_arg(upd); cu::KernelArgs args;
mod.append_arg(idx); args.append(upd);
mod.append_arg(out); args.append(idx);
args.append(out);
if (large) { if (large) {
mod.append_arg<int64_t>(idx_size_pre); args.append<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis); args.append<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post); args.append<int64_t>(idx_size_post);
} else { } else {
mod.append_arg<int32_t>(idx_size_pre); args.append<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis); args.append<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post); args.append<int32_t>(idx_size_post);
} }
mod.append_arg(remove_index(idx.shape(), axis_)); args.append(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_)); args.append(remove_index(upd.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_)); args.append(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_); args.append<int32_t>(axis_);
mod.append_arg(out.shape(axis_)); args.append(out.shape(axis_));
mod.append_arg(upd.strides(axis_)); args.append(upd.strides(axis_));
mod.append_arg(idx.strides(axis_)); args.append(idx.strides(axis_));
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
@@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { auto kernel = mod.get_kernel(kernel_name);
mod.launch_kernel(stream, kernel_name, idx, large); auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
}); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
} // namespace mlx::core } // namespace mlx::core
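All four indexing primitives now share the same flow: marshal arguments into a KernelArgs (which owns the scalar and vector storage the launch needs stable pointers into), look the kernel up in the JIT module, size the launch with get_launch_args, and record it on the encoder. A condensed sketch of that flow with a hypothetical kernel name:

// Sketch only: the shared JIT-launch pattern used by Gather/Scatter above.
void launch_jit(
    cu::CommandEncoder& encoder,
    cu::JitModule& mod,
    const array& in,
    array& out,
    bool large) {
  cu::KernelArgs args;
  args.append(in);   // device pointer of the input
  args.append(out);  // device pointer of the output
  if (large) {
    args.append<int64_t>(out.size());
  } else {
    args.append<int32_t>(out.size());
  }
  auto kernel = mod.get_kernel("my_jit_kernel");  // hypothetical name
  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}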

View File

@@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
} }
} }
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
void check_cu_error(const char* name, CUresult err) {
if (err != CUDA_SUCCESS) {
const char* err_str = "Unknown error";
cuGetErrorString(err, &err_str);
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
}
// Return the location of the CUDA toolkit. // Return the location of the CUDA toolkit.
const std::string& cuda_home() { const std::string& cuda_home() {
static std::string home = []() -> std::string { static std::string home = []() -> std::string {
@@ -280,60 +270,13 @@ JitModule::JitModule(
// Load kernels. // Load kernels.
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel; kernels_[name] = kernel;
} }
} }
JitModule::~JitModule() { JitModule::~JitModule() {
CHECK_CU_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
void JitModule::launch_kernel(
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread) {
CUfunction kernel = get_kernel(kernel_name);
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
int _, block_dim;
CHECK_CU_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
if (block_dim > nthreads) {
block_dim = nthreads;
}
Dims num_blocks{1, 1, 1};
if (large) {
num_blocks =
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
std::get<0>(num_blocks) =
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
} else {
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
}
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
}
void JitModule::launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims) {
CHECK_CU_ERROR(cuLaunchKernel(
kernel,
std::get<0>(num_blocks),
std::get<1>(num_blocks),
std::get<2>(num_blocks),
std::get<0>(block_dims),
std::get<1>(block_dims),
std::get<2>(block_dims),
0,
stream,
args_.data(),
nullptr));
args_.clear();
storage_.clear();
} }
CUfunction JitModule::get_kernel(const std::string& kernel_name) { CUfunction JitModule::get_kernel(const std::string& kernel_name) {
@@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second; return it->second;
} }
void JitModule::append_ptr_arg(const void* v) {
args_.push_back(const_cast<void*>(v));
}
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,

View File

@@ -4,6 +4,7 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include <deque> #include <deque>
@@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair<
/* kernel names */ std::vector<std::string>>; /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>; using KernelBuilder = std::function<KernelBuilderResult()>;
class JitModule { struct KernelArgs {
public: void** args() {
JitModule( return args_.data();
Device& device, }
const std::string& module_name,
const KernelBuilder& builder);
~JitModule();
JitModule(const JitModule&) = delete; void append(const array& a) {
JitModule& operator=(const JitModule&) = delete; append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
void append_arg(const array& a) {
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
} }
template <typename T> template <typename T>
void append_arg(T val) { void append(T val) {
storage_.emplace_back(val); storage_.emplace_back(val);
append_ptr_arg(&storage_.back()); append_ptr(&storage_.back());
} }
template <typename T> template <typename T>
void append_arg(std::vector<T> vec) { void append(std::vector<T> vec) {
if (vec.empty()) { if (vec.empty()) {
// The nullptr can not be used as arg, pass something not null. // The nullptr can not be used as arg, pass something not null.
append_arg(std::monostate{}); append(std::monostate{});
} else { } else {
append_ptr_arg(vec.data()); append_ptr(vec.data());
storage_.emplace_back(std::move(vec)); storage_.emplace_back(std::move(vec));
} }
} }
// Make sure the arg is copied to an array with size of NDIM. // Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T> template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim_arg(const std::vector<T>& vec) { void append_ndim(std::vector<T> vec) {
if (vec.size() > NDIM) { if (vec.size() > NDIM) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM)); fmt::format("ndim can not be larger than {}.", NDIM));
} }
std::vector<T> copied(NDIM); vec.resize(NDIM);
std::copy(vec.begin(), vec.end(), copied.data()); append(std::move(vec));
append_arg(std::move(copied));
} }
// Launch kernel with |kernel_name| that each thread works on void append_ptr(const void* v) {
// |work_per_thread| elements of |arr|. args_.push_back(const_cast<void*>(v));
void launch_kernel( }
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread = 1);
void launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims);
CUfunction get_kernel(const std::string& kernel_name);
private: private:
void append_ptr_arg(const void* v);
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
std::vector<void*> args_; std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store // The cuLaunchKernel API requires passing pointers to arguments so store
@@ -105,6 +82,23 @@ class JitModule {
std::deque<Arg> storage_; std::deque<Arg> storage_;
}; };
class JitModule {
public:
JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
};
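KernelArgs::append_ndim pads a shape or strides vector out to NDIM (MAX_NDIM by default) so the kernel can declare a fixed-size array parameter regardless of the actual rank; the padded tail is value-initialized to zero and the real ndim is passed separately. A small sketch:

// Sketch only: padding behavior of append_ndim.
void append_shape(cu::KernelArgs& args, std::vector<int32_t> shape /* e.g. {4, 8, 16} */) {
  int32_t ndim = static_cast<int32_t>(shape.size());
  args.append_ndim(std::move(shape));  // delivered as int32_t[MAX_NDIM], tail zero-filled
  args.append(ndim);                   // the real rank travels alongside
}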
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,

View File

@@ -12,6 +12,7 @@
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h> #include <cuComplex.h>
#include <cuda.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <fmt/format.h> #include <fmt/format.h>
@@ -120,7 +121,13 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
template <typename T> template <typename T>
inline uint max_occupancy_block_dim(T kernel) { inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim; int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim; return block_dim;
} }
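max_occupancy_block_dim now also accepts the driver-API CUfunction handles produced by JitModule, using cuOccupancyMaxPotentialBlockSize for those and keeping cudaOccupancyMaxPotentialBlockSize for ordinary kernel pointers. Usage sketch; my_kernel and the JIT kernel name are hypothetical:

// Sketch only: one helper sizes blocks for both kinds of kernels.
__global__ void my_kernel(float* out, int n);  // hypothetical compiled-in kernel

void pick_block_dims(cu::JitModule& mod) {
  uint runtime_dim = max_occupancy_block_dim(my_kernel);    // runtime-API query
  CUfunction jit_kernel = mod.get_kernel("my_jit_kernel");  // hypothetical JIT kernel
  uint driver_dim = max_occupancy_block_dim(jit_kernel);    // driver-API query
  (void)runtime_dim;
  (void)driver_dim;
}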

View File

@@ -258,23 +258,23 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { constexpr uint32_t N_READS = 4;
constexpr uint32_t N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
x.data<DataType>(), block_dim(),
w.data<DataType>(), x.data<DataType>(),
b.data<DataType>(), w.data<DataType>(),
out.data<DataType>(), b.data<DataType>(),
eps_, out.data<DataType>(),
axis_size, eps_,
w_stride, axis_size,
b_stride); w_stride,
}); b_stride);
}); });
}); });
} }
@@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array // Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the // is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler. // same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> { auto check_input = [&s](const array& x, bool& copied) {
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; copied = false;
return x;
} }
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s); copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable(); bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]); bool copied;
auto x = check_input(inputs[0], copied);
donate_x |= copied; donate_x |= copied;
const array& w = inputs[1]; const array& w = inputs[1];
const array& b = inputs[2]; const array& b = inputs[2];
auto [g, g_copied] = check_input(inputs[3]); bool g_copied;
auto g = check_input(inputs[3], g_copied);
donate_g |= g_copied; donate_g |= g_copied;
array& gx = outputs[0]; array& gx = outputs[0];
array& gw = outputs[1]; array& gw = outputs[1];
@@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu(
// gradient accumulators. // gradient accumulators.
array gw_temp = array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
bool g_in_gw = false;
if (has_w) { if (has_w) {
if (!g_in_gx && donate_g) { if (!g_in_gx && donate_g) {
g_in_gw = true;
gw_temp.copy_shared_buffer(g); gw_temp.copy_shared_buffer(g);
} else { } else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
@@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu(
} }
} }
// Finish with the gradient for b in case we had a b. // The gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) { bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size);
if (has_gb) {
ReductionPlan plan( ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
} }
// Insert dependency if `g` was donated
if ((g_in_gx || g_in_gw) && has_gb) {
encoder.set_input_array(gb);
}
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(g); encoder.set_input_array(g);
encoder.set_output_array(gx); encoder.set_output_array(gx);
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) {
dispatch_bool(has_w, [&](auto has_w_constant) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(
dispatch_block_dim( cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; auto kernel = cu::layer_norm_vjp<
auto kernel = cu::layer_norm_vjp< DataType,
DataType, has_w_constant.value,
has_w_constant.value, block_dim(),
block_dim(), N_READS>;
N_READS>; encoder.add_kernel_node(
kernel<<<n_rows, block_dim(), 0, stream>>>( kernel,
x.data<DataType>(), n_rows,
w.data<DataType>(), block_dim(),
g.data<DataType>(), x.data<DataType>(),
gx.data<DataType>(), w.data<DataType>(),
gw_temp.data<DataType>(), g.data<DataType>(),
eps_, gx.data<DataType>(),
axis_size, gw_temp.data<DataType>(),
w_stride); eps_,
}); axis_size,
}); w_stride);
});
}); });
}); });
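One subtle addition above: when g's buffer was donated to gx or gw_temp, the vjp kernel overwrites memory that the column reduction producing gb still reads, so gb is declared as an input of the vjp node purely to order the graph. The same lines with explanatory comments:

// Sketch only: an ordering-only dependency between two recorded nodes. Marking
// gb as an input of the upcoming vjp node forces the col_reduce that produced
// gb to finish before the kernel that may reuse g's donated buffer runs.
if ((g_in_gx || g_in_gw) && has_gb) {
  encoder.set_input_array(gb);  // the kernel reads no data from gb; graph ordering only
}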

View File

@@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
in.data<DataType>(), out.data<DataType>(), axis_size); block_dim(),
}); in.data<DataType>(),
out.data<DataType>(),
axis_size);
}); });
}); });
} }
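The row kernels above (layernorm, logsumexp) all follow the same dispatch shape: the runtime row length selects a compile-time block dimension through dispatch_block_dim, whose argument acts like an integral constant (block_dim() is usable in template arguments), and the instantiated kernel is recorded on the encoder instead of launched with <<<...>>>. A condensed sketch with a hypothetical row kernel:

// Sketch only: runtime size -> compile-time block dimension -> graph node.
void record_row_op(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    int axis_size,
    int n_rows) {
  constexpr int N_READS = 4;
  dispatch_float_types(out.dtype(), "my_row_op", [&](auto type_tag) {
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      // my_row_kernel is hypothetical; block_dim() is a constant expression here.
      auto kernel = cu::my_row_kernel<DataType, block_dim(), N_READS>;
      encoder.add_kernel_node(
          kernel,
          n_rows,
          block_dim(),
          in.data<DataType>(),
          out.data<DataType>(),
          axis_size);
    });
  });
}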

View File

@@ -42,7 +42,8 @@ class MatMul {
int64_t ldb, int64_t ldb,
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride) { int64_t b_batch_stride)
: handle_(device.lt_handle()) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype); auto scale_type = dtype_to_cuda_type(dtype);
@@ -147,7 +148,7 @@ class MatMul {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0; int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
encoder.device().lt_handle(), handle_,
matmul_desc_, matmul_desc_,
a_desc_, a_desc_,
b_desc_, b_desc_,
@@ -172,25 +173,24 @@ class MatMul {
workspace_ptr = workspace.data<void>(); workspace_ptr = workspace.data<void>();
} }
encoder.launch_kernel([&](cudaStream_t stream) { auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
encoder.device().lt_handle(), handle_,
matmul_desc_, matmul_desc_,
&alpha, &alpha,
a, a,
a_desc_, a_desc_,
b, b,
b_desc_, b_desc_,
&beta, &beta,
c ? c : out, c ? c : out,
c ? c_desc_ : out_desc_, c ? c_desc_ : out_desc_,
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace_ptr, workspace_ptr,
heuristic_.workspaceSize, heuristic_.workspaceSize,
stream)); encoder.stream()));
});
} }
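cublasLt calls cannot be added as explicit kernel nodes, so the encoder exposes a capture_context: while the returned RAII object is alive, work submitted to encoder.stream() is captured into the same graph as the surrounding nodes (the exact mechanism is internal to CommandEncoder). The cuBLAS handle is also resolved once in the constructor rather than fetched from the encoder on every run. A sketch of the capture idiom around an arbitrary asynchronous library call:

// Sketch only: route a library call through the encoder's capture context.
void record_library_call(cu::CommandEncoder& encoder, float* buf, size_t n) {
  auto capture = encoder.capture_context();  // capture is active while this lives
  CHECK_CUDA_ERROR(cudaMemsetAsync(buf, 0, n * sizeof(float), encoder.stream()));
}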
private: private:
@@ -259,6 +259,7 @@ class MatMul {
return desc; return desc;
} }
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr};
@@ -273,7 +274,7 @@ class MatMul {
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) { check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && stx == arr.shape(-1)) { if (sty == 1 && stx == arr.shape(-1)) {
@@ -283,7 +284,7 @@ check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s); copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }
} }
@@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release // Keep a vector with copies to be cleared in the completed buffer to release
// the arrays // the arrays
std::vector<array> copies; auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions
@@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Invoke cublasLt // Invoke cublasLt
cu::MatMul matmul( cu::MatMul matmul(
encoder.device(), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
M, M,
@@ -373,6 +369,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
@@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release // Keep a vector with copies to be cleared in the completed buffer to release
// the arrays // the arrays
std::vector<array> copies; auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions
@@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Invoke cublasLt // Invoke cublasLt
cu::MatMul matmul( cu::MatMul matmul(
encoder.device(), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
M, M,
@@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,

View File

@@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
auto& s = stream(); auto& encoder = cu::get_command_encoder(stream());
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) { auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>; using OutType = cuda_type_t<CTYPE>;
CTYPE step = CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_); static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform( thrust::transform(
cu::thrust_policy(stream), cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0), thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()), thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()), thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{ cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)}); static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}); });
} }

View File

@@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys); encoder.set_input_array(keys);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dim3 grid_dims{num_keys, half_size + odd};
dim3 grid_dims{num_keys, half_size + odd}; int64_t total = grid_dims.x * grid_dims.y;
int64_t total = grid_dims.x * grid_dims.y; int32_t threads_y = 1;
int32_t threads_y = 1; while ((total / threads_y) >= (1U << 31)) {
while ((total / threads_y) >= (1U << 31)) { threads_y *= 2;
threads_y *= 2; }
} int32_t threads_x = cuda::ceil_div(total, threads_y);
int32_t threads_x = cuda::ceil_div(total, threads_y); auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); auto& stream = encoder.stream();
if (keys.flags().row_contiguous) { if (keys.flags().row_contiguous) {
cu::rbitsc<<<grid, block, 0, stream>>>( encoder.add_kernel_node(
keys.data<uint32_t>(), cu::rbitsc,
out.data<uint8_t>(), grid,
grid_dims, block,
odd, keys.data<uint32_t>(),
bytes_per_key); out.data<uint8_t>(),
} else { grid_dims,
cu::rbits<<<grid, block, 0, stream>>>( odd,
keys.data<uint32_t>(), bytes_per_key);
out.data<uint8_t>(), } else {
grid_dims, encoder.add_kernel_node(
odd, cu::rbits,
bytes_per_key, grid,
keys.ndim(), block,
const_param(keys.shape()), keys.data<uint32_t>(),
const_param(keys.strides())); out.data<uint8_t>(),
} grid_dims,
}); odd,
bytes_per_key,
keys.ndim(),
const_param(keys.shape()),
const_param(keys.strides()));
}
} }
} // namespace mlx::core } // namespace mlx::core
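The random kernels keep their folding trick when moving to graph nodes: the flat index space num_keys * (half_size + odd) is split into an (x, y) extent that keeps x below 2^31, and get_grid_and_block turns that extent into grid and block sizes. The same computation in isolation:

// Sketch only: fold a flat extent into (x, y) so x stays below 2^31.
void pick_grid(int64_t total) {
  int32_t threads_y = 1;
  while ((total / threads_y) >= (1U << 31)) {
    threads_y *= 2;  // grow y until the x extent fits
  }
  int32_t threads_x = cuda::ceil_div(total, threads_y);
  auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
  // grid and block then go straight into encoder.add_kernel_node(...).
  (void)grid;
  (void)block;
}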

View File

@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (; i + block.size() * N <= check; i += block.size() * N) { for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals); cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j])); accs[0] = op(accs[0], cast_to<U>(vals[j]));
} }
} }
if (i < check) { if (i < check) {
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init)); block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i])); accs[0] = op(accs[0], cast_to<U>(vals[i]));
} }
} }
@@ -110,19 +110,20 @@ void all_reduce(
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; encoder.add_kernel_node(
kernel<<<blocks, threads, 0, stream>>>( kernel,
static_cast<T*>(indata), blocks,
intermediate.data<U>(), threads,
block_step, static_cast<T*>(indata),
insize); intermediate.data<U>(),
}); block_step,
insize);
}); });
}); });
@@ -135,16 +136,20 @@ void all_reduce(
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; encoder.add_kernel_node(
kernel<<<blocks, threads, 0, stream>>>( kernel,
static_cast<T*>(indata), out.data<U>(), block_step, insize); blocks,
}); threads,
static_cast<T*>(indata),
out.data<U>(),
block_step,
insize);
}); });
}); });
} }

View File

@@ -3,7 +3,6 @@
#include <numeric> #include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i])); totals[i] = op(totals[i], cast_to<U>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i])); totals[i] = op(totals[i], cast_to<U>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
in + loop.location(), in + loop.location(),
vals, vals,
args.reduction_stride - tile_x * BN, args.reduction_stride - tile_x * BN,
__cast<T, U>(ReduceInit<Op, T>::value())); cast_to<T>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i])); totals[i] = op(totals[i], cast_to<U>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
@@ -214,26 +213,24 @@ void col_reduce_looped(
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Cub doesn't like const pointers for vectorized loads. (sigh) constexpr int N_READS = 4;
T* indata = const_cast<T*>(in.data<T>()); constexpr int BM = 32;
constexpr int BN = 32;
constexpr int N_READS = 4; dim3 grid = output_grid_for_col_reduce(out, args, BN);
constexpr int BM = 32; int blocks = BM * BN / N_READS;
constexpr int BN = 32; auto kernel =
dim3 grid = output_grid_for_col_reduce(out, args, BN); cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
int blocks = BM * BN / N_READS; encoder.add_kernel_node(
auto kernel = kernel, grid, blocks, indata, out.data<U>(), args);
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
});
}); });
}); });
}); });

View File

@@ -32,18 +32,16 @@ void init_reduce(
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::init_reduce<T, U, OP>;
auto kernel = cu::init_reduce<T, U, OP>; dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024;
grid.x = (grid.x + 1023) / 1024; encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
}); });
}); });
} }

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/atomic_ops.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_utils.cuh"
@@ -40,15 +42,15 @@ struct Sum {
} }
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomicAdd(x, y); atomic_add(x, y);
} }
__device__ void atomic_update(int* x, int y) { __device__ void atomic_update(int* x, int y) {
atomicAdd(x, y); atomic_add(x, y);
} }
__device__ void atomic_update(float* x, float y) { __device__ void atomic_update(float* x, float y) {
atomicAdd(x, y); atomic_add(x, y);
} }
}; };
@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0}; return T{0, 0};
} else { } else {
return typename ReduceResult<Sum, T>::type{0}; return cast_to<typename ReduceResult<Sum, T>::type>(0);
} }
} }
}; };
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 0}; return T{1, 0};
} else { } else {
return typename ReduceResult<Prod, T>::type{1}; return cast_to<typename ReduceResult<Prod, T>::type>(1);
} }
} }
}; };

View File

@@ -55,22 +55,6 @@ __device__ void atomic_reduce(T* x, T y) {
} }
} }
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
template <typename T, int N, typename Block, typename Warp, typename Op> template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {

View File

@@ -3,7 +3,6 @@
#include <numeric> #include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -113,7 +112,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + r * (block.size() * N), in + k * size + r * (block.size() * N),
vals[k]); vals[k]);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
} }
} }
@@ -125,7 +124,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + r * (block.size() * N), in + k * size + r * (block.size() * N),
vals[k]); vals[k]);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
} }
} }
@@ -138,9 +137,9 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + final_offset, in + k * size + final_offset,
vals[k], vals[k],
size, size,
__cast<T, U>(init)); cast_to<T>(init));
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
} }
} }
@@ -199,7 +198,7 @@ __global__ void row_reduce_looped(
in + loop.location() + r * BLOCK_DIM * N_READS, in + loop.location() + r * BLOCK_DIM * N_READS,
vals); vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i])); total[0] = op(total[0], cast_to<U>(vals[i]));
} }
} }
if (final_offset < args.row_size) { if (final_offset < args.row_size) {
@@ -209,9 +208,9 @@ __global__ void row_reduce_looped(
in + loop.location() + final_offset, in + loop.location() + final_offset,
vals, vals,
args.row_size - final_offset, args.row_size - final_offset,
__cast<T, U>(init)); cast_to<T>(init));
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i])); total[0] = op(total[0], cast_to<U>(vals[i]));
} }
} }
// TODO: Maybe block.sync() here? // TODO: Maybe block.sync() here?
@@ -245,34 +244,32 @@ void row_reduce_simple(
// 2 passes. Something like 32 * out.size() and then do a warp reduce. // 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>()); T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions); int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>; auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
if (grid.x >= 1024) { if (grid.x >= 1024) {
grid.x = (grid.x + 1) / 2; grid.x = (grid.x + 1) / 2;
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>; kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
} }
// Launch int size = plan.shape.back();
kernel<<<grid, block, 0, stream>>>( encoder.add_kernel_node(
indata, out.data<U>(), out.size(), plan.shape.back()); kernel, grid, block, indata, out.data<U>(), out.size(), size);
});
}); });
}); });
} }
@@ -293,43 +290,39 @@ void row_reduce_looped(
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Cub doesn't like const pointers for vectorized loads. (sigh) // Calculate the grid and block dims
T* indata = const_cast<T*>(in.data<T>()); args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Calculate the grid and block dims // Pick the kernel
args.sort_access_pattern(in, axes); auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
size_t reductions = (args.row_size + N_READS - 1) / N_READS; dispatch_block_dim(threads, [&](auto threads_constant) {
int threads = std::min(1024UL, reductions); kernel = cu::row_reduce_looped<
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; T,
dim3 block(threads, 1, 1); U,
OP,
// Pick the kernel reduce_ndim.value,
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>; threads_constant.value,
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { N_READS>;
dispatch_block_dim(threads, [&](auto threads_constant) { block.x = threads_constant.value;
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
}); });
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), args);
}); });
encoder.add_kernel_node(
kernel, grid, block, indata, out.data<U>(), out.size(), args);
}); });
}); });
} }
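
The hunks above are representative of the CUDA-graph migration in this compare: direct kernel<<<grid, block, 0, stream>>> launches inside encoder.launch_kernel become encoder.add_kernel_node(kernel, grid, block, args...), which records the launch as a graph node instead of issuing it eagerly. For orientation only, a stand-alone sketch of a node-based launch at the plain CUDA runtime level could look like the following; my_kernel and launch_via_graph are hypothetical names, error checking is omitted, and this is not MLX code.

#include <cuda_runtime.h>

// Hypothetical kernel used only for the sketch.
__global__ void my_kernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] *= 2.0f;
  }
}

// Record the launch as a graph node, instantiate the graph, then launch it.
void launch_via_graph(float* out, int n, cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  // Describe the launch (function, grid, block, arguments) as node parameters.
  void* args[] = {&out, &n};
  cudaKernelNodeParams params = {};
  params.func = (void*)my_kernel;
  params.gridDim = dim3((n + 255) / 256);
  params.blockDim = dim3(256);
  params.sharedMemBytes = 0;
  params.kernelParams = args;

  cudaGraphNode_t node;
  cudaGraphAddKernelNode(&node, graph, nullptr, 0, &params);

  // Instantiate once; the executable graph can then be launched repeatedly.
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
}

Note that in the diffs the grid and block computation is unchanged; only the final launch call changes shape, with the encoder owning the graph bookkeeping.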
View File
@@ -74,7 +74,7 @@ __global__ void rms_norm(
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, 0); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
normalizer += t * t; normalizer += t * t;
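
In the load above, the out-of-bounds fill value changes from the literal 0 to cast_to<T>(0), so the padding element is constructed explicitly in the element type; for reduced-precision types such as __half or __nv_bfloat16 the implicit conversion from an int literal can be ambiguous or unavailable depending on the toolchain, which is presumably what this avoids. A guess at the shape of such a helper, shown only to make the hunk readable (the real cast_to in MLX may handle more cases, e.g. complex types):

// Hypothetical helper: an explicit cast spelled as a function template so a
// typed zero (or init value) can be passed where a bare literal would not
// convert. Compile with nvcc; __host__ __device__ are CUDA qualifiers.
template <typename To, typename From>
__host__ __device__ inline To cast_to(From x) {
  return static_cast<To>(x);
}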
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
T wn[N_READS] = {}; T wn[N_READS] = {};
T gn[N_READS] = {}; T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, 0); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -224,21 +224,21 @@ void RMSNorm::eval_gpu(
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { constexpr uint32_t N_READS = 4;
constexpr uint32_t N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
x.data<DataType>(), block_dim(),
w.data<DataType>(), x.data<DataType>(),
out.data<DataType>(), w.data<DataType>(),
eps_, out.data<DataType>(),
axis_size, eps_,
w_stride); axis_size,
}); w_stride);
}); });
}); });
} }
@@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array // Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the // is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler. // same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> { auto check_input = [&s](const array& x, bool& copied) {
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; copied = false;
return x;
} }
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s); copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable(); bool donate_g = inputs[2].is_donatable();
auto [x, copied] = check_input(inputs[0]); bool copied;
auto x = check_input(inputs[0], copied);
donate_x |= copied; donate_x |= copied;
const array& w = inputs[1]; const array& w = inputs[1];
auto [g, g_copied] = check_input(inputs[2]); bool g_copied;
auto g = check_input(inputs[2], g_copied);
donate_g |= g_copied; donate_g |= g_copied;
array& gx = outputs[0]; array& gx = outputs[0];
array& gw = outputs[1]; array& gw = outputs[1];
@@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu(
encoder.set_input_array(g); encoder.set_input_array(g);
encoder.set_output_array(gx); encoder.set_output_array(gx);
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) {
dispatch_bool(has_w, [&](auto has_w_constant) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(
dispatch_block_dim( cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; constexpr int N_READS = 4;
constexpr int N_READS = 4; auto kernel = cu::rms_norm_vjp<
auto kernel = cu::rms_norm_vjp< DataType,
DataType, has_w_constant.value,
has_w_constant.value, block_dim(),
block_dim(), N_READS>;
N_READS>; encoder.add_kernel_node(
kernel<<<n_rows, block_dim(), 0, stream>>>( kernel,
x.data<DataType>(), n_rows,
w.data<DataType>(), block_dim(),
g.data<DataType>(), x.data<DataType>(),
gx.data<DataType>(), w.data<DataType>(),
gw_temp.data<DataType>(), g.data<DataType>(),
eps_, gx.data<DataType>(),
axis_size, gw_temp.data<DataType>(),
w_stride); eps_,
}); axis_size,
}); w_stride);
});
}); });
}); });
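
These eval_gpu hunks keep the dispatch_float_types / dispatch_block_dim pattern: a runtime thread count derived from ceil_div(axis_size, N_READS) is mapped onto a small set of supported block sizes so that block_dim() is a compile-time constant inside the lambda and can parameterize the kernel template. A minimal sketch of that kind of dispatcher, under the assumption that it forwards std::integral_constant values (the actual MLX helper may differ in the supported sizes):

#include <type_traits>

// Hypothetical dispatcher: invoke f with the smallest supported block size
// that covers `threads`, passed as a std::integral_constant.
template <typename F>
void dispatch_block_dim_sketch(int threads, F&& f) {
  if (threads <= 128) {
    f(std::integral_constant<int, 128>{});
  } else if (threads <= 256) {
    f(std::integral_constant<int, 256>{});
  } else if (threads <= 512) {
    f(std::integral_constant<int, 512>{});
  } else {
    f(std::integral_constant<int, 1024>{});
  }
}

// Usage mirrors the diff: block_dim() evaluates to a constant expression, so
// it can be used as a template argument of the kernel.
// dispatch_block_dim_sketch(threads, [&](auto block_dim) {
//   auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
//   ...
// });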
View File
@@ -308,76 +308,89 @@ void RoPE::eval_gpu(
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in); encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset); encoder.set_input_array(offset);
if (with_freqs) {
encoder.set_input_array(inputs[2]);
}
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { dispatch_bool(traditional_, [&](auto traditional) {
dispatch_bool(traditional_, [&](auto traditional) { dispatch_bool(forward_, [&](auto forward) {
dispatch_bool(forward_, [&](auto forward) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; if (single && !with_freqs) {
if (single && !with_freqs) { auto kernel =
auto kernel = cu::rope_single<DataType, traditional.value, forward.value>;
cu::rope_single<DataType, traditional.value, forward.value>; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); encoder.add_kernel_node(
kernel<<<grid, block, 0, stream>>>( kernel,
(donated ? out : in).data<DataType>(), grid,
out.data<DataType>(), block,
offset.data<int32_t>(), (donated ? out : in).data<DataType>(),
scale_, out.data<DataType>(),
std::log2(base_), offset.data<int32_t>(),
mat_size, scale_,
dims); std::log2(base_),
} else if (single) { mat_size,
auto kernel = cu:: dims);
rope_single_freqs<DataType, traditional.value, forward.value>; } else if (single) {
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto kernel =
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); cu::rope_single_freqs<DataType, traditional.value, forward.value>;
kernel<<<grid, block, 0, stream>>>( uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
(donated ? out : in).data<DataType>(), auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
out.data<DataType>(), encoder.add_kernel_node(
offset.data<int32_t>(), kernel,
inputs[2].data<float>(), grid,
scale_, block,
mat_size, (donated ? out : in).data<DataType>(),
dims, out.data<DataType>(),
inputs[2].strides(0)); offset.data<int32_t>(),
} else if (with_freqs) { inputs[2].data<float>(),
auto kernel = scale_,
cu::rope_freqs<DataType, traditional.value, forward.value>; mat_size,
uint3 dims = dims,
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); inputs[2].strides(0));
dims.z = (dims.z + 3) / 4; } else if (with_freqs) {
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto kernel =
kernel<<<grid, block, 0, stream>>>( cu::rope_freqs<DataType, traditional.value, forward.value>;
(donated ? out : in).data<DataType>(), uint3 dims =
out.data<DataType>(), make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
offset.data<int32_t>(), dims.z = (dims.z + 3) / 4;
inputs[2].data<float>(), auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
scale_, encoder.add_kernel_node(
std::log2(base_), kernel,
strides, grid,
out_strides, block,
in.size() / mat_size, (donated ? out : in).data<DataType>(),
dims, out.data<DataType>(),
inputs[2].strides(0)); offset.data<int32_t>(),
} else { inputs[2].data<float>(),
auto kernel = cu::rope<DataType, traditional.value, forward.value>; scale_,
uint3 dims = std::log2(base_),
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); strides,
dims.z = (dims.z + 3) / 4; out_strides,
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); in.size() / mat_size,
kernel<<<grid, block, 0, stream>>>( dims,
(donated ? out : in).data<DataType>(), inputs[2].strides(0));
out.data<DataType>(), } else {
offset.data<int32_t>(), auto kernel = cu::rope<DataType, traditional.value, forward.value>;
scale_, uint3 dims =
std::log2(base_), make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
strides, dims.z = (dims.z + 3) / 4;
out_strides, auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
in.size() / mat_size, encoder.add_kernel_node(
dims); kernel,
} grid,
}); block,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
}); });
}); });
}); });
View File
@@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
// Thread reduce. // Thread reduce.
AccT prevmax; AccT prevmax;
AccT maxval = Limits<AccT>::finite_min(); AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0; AccT normalizer = cast_to<AccT>(0);
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS]; AccT vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
@@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; if (precise) {
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>; kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
if (precise) { }
kernel = cu::softmax<DataType, float, block_dim(), N_READS>; encoder.add_kernel_node(
} kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
in.data<DataType>(), out.data<DataType>(), axis_size); block_dim(),
}); in.data<DataType>(),
out.data<DataType>(),
axis_size);
}); });
}); });
} }
View File
@@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) {
return out; return out;
} }
template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(), size, args...));
}
template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(), size, args...));
}
struct OffsetTransform { struct OffsetTransform {
int nsort; int nsort;
@@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag);
using CTYPE = MLX_GET_TYPE(type_tag); auto& stream = encoder.stream();
if constexpr (!std::is_same_v<CTYPE, complex64_t>) { if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>; using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator( auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), OffsetTransform{nsort}); thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) { if (argsort) {
// Indices in the sorted dimension. // Indices in the sorted dimension.
array indices( array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
allocator::malloc(out.nbytes()), in.shape(), out.dtype()); encoder.add_temporary(indices);
encoder.add_temporary(indices);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
// In argsort though we don't need the result of sorted values, the // In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it. // API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard); encoder.add_temporary(discard);
segmented_sort_pairs( size_t size;
encoder, CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
in.data<Type>(), nullptr,
discard.data<Type>(), size,
indices.data<uint32_t>(), in.data<Type>(),
out.data<uint32_t>(), discard.data<Type>(),
in.data_size(), indices.data<uint32_t>(),
in.data_size() / nsort, out.data<uint32_t>(),
offsets, in.data_size(),
offsets + 1, in.data_size() / nsort,
stream); offsets,
} else { offsets + 1,
segmented_sort( stream));
encoder,
in.data<Type>(), array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
out.data<Type>(), encoder.add_temporary(temp);
in.data_size(),
in.data_size() / nsort, // Start capturing after allocations
offsets, auto capture = encoder.capture_context();
offsets + 1, thrust::transform(
stream); cu::thrust_policy(stream),
} thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(),
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
} else { } else {
throw std::runtime_error( size_t size;
"CUDA backend does not support sorting complex numbers"); CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
nullptr,
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(),
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
} }
}); } else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
}); });
if (!is_segmented_sort) { if (!is_segmented_sort) {
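
The segmented_sort / segmented_sort_pairs helpers are removed above and their bodies inlined so that the temporary-storage allocation happens before encoder.capture_context() starts capturing (per the "Start capturing after allocations" comment), presumably because the allocation cannot be recorded into the graph. The underlying CUB idiom is the usual two-phase call: query the scratch size with a null temp pointer, allocate, then run. A stand-alone sketch of that idiom with hypothetical names and a plain cudaMalloc in place of the MLX allocator:

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Stable-sort the keys of `num_segments` segments; `seg_offsets` holds
// num_segments + 1 boundaries, analogous to the offset iterator in the diff.
void segmented_sort_keys_sketch(
    const float* keys_in,
    float* keys_out,
    int num_items,
    int num_segments,
    const int* seg_offsets,
    cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: with a null temp pointer CUB only reports the scratch size.
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, keys_in, keys_out, num_items, num_segments,
      seg_offsets, seg_offsets + 1, stream);

  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes);

  // Phase 2: the actual sort, now using the scratch buffer.
  cub::DeviceSegmentedSort::StableSortKeys(
      temp, temp_bytes, keys_in, keys_out, num_items, num_segments,
      seg_offsets, seg_offsets + 1, stream);

  cudaFree(temp);
}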
View File
@@ -91,73 +91,80 @@ void ternary_op_gpu_inplace(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_input_array(c); encoder.set_input_array(c);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(out.dtype(), [&](auto type_tag) {
dispatch_all_types(out.dtype(), [&](auto type_tag) { using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
dispatch_bool( dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) { [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape; Shape shape;
std::vector<Strides> strides; std::vector<Strides> strides;
std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
auto& c_strides = strides[2]; auto& c_strides = strides[2];
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>; cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides),
const_param<dims_constant()>(c_strides));
});
} else {
auto kernel = cu::ternary_g<Op, DType, IdxT>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large()); get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>( encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<bool>(), a.data<bool>(),
b.data<DType>(), b.data<DType>(),
c.data<DType>(), c.data<DType>(),
out.data<DType>(), out.data<DType>(),
out.data_size(), out.size(),
const_param(shape), const_param<dims_constant()>(shape),
const_param(a_strides), const_param<dims_constant()>(a_strides),
const_param(b_strides), const_param<dims_constant()>(b_strides),
const_param(c_strides), const_param<dims_constant()>(c_strides));
ndim); });
} } else {
}); auto kernel = cu::ternary_g<Op, DType, IdxT>;
} else { auto [num_blocks, block_dims] =
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { get_launch_args(kernel, out, large());
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; encoder.add_kernel_node(
auto kernel = cu::ternary_v<Op, DType, IdxT>; kernel,
auto [num_blocks, block_dims] = get_launch_args( num_blocks,
kernel, out.data_size(), out.shape(), out.strides(), large()); block_dims,
kernel<<<num_blocks, block_dims, 0, stream>>>( a.data<bool>(),
a.data<bool>(), b.data<DType>(),
b.data<DType>(), c.data<DType>(),
c.data<DType>(), out.data<DType>(),
out.data<DType>(), out.data_size(),
out.data_size()); const_param(shape),
}); const_param(a_strides),
} const_param(b_strides),
}); const_param(c_strides),
ndim);
}
});
} else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
}); });
} }
View File
@@ -9,14 +9,38 @@
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(in[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
}
}
template <typename Op, typename In, typename Out> template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() { constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> || if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
@@ -71,38 +95,61 @@ void unary_op_gpu_inplace(
if (in.size() == 0) { if (in.size() == 0) {
return; return;
} }
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_IN = MLX_GET_TYPE(in_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) { dispatch_bool(large, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using InType = cuda_type_t<CTYPE_IN>; using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
auto policy = cu::thrust_policy(stream); using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto in_ptr = thrust::device_pointer_cast(in.data<InType>()); if (contig) {
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>()); auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
if (in.flags().contiguous) { auto [num_blocks, block_dims] = get_launch_args(
thrust::transform( kernel, out.data_size(), out.shape(), out.strides(), large);
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else { } else {
auto [shape, strides] = collapse_contiguous_dims(in); auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>( auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
in_ptr, in.size(), shape, strides); auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
thrust::transform(policy, in_begin, in_end, out_ptr, Op()); encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param(shape),
const_param(strides),
shape.size());
} }
} else { });
throw std::runtime_error(fmt::format( } else {
"Can not do unary op {} on input of {} with output of {}.", throw std::runtime_error(fmt::format(
op, "Can not do unary op {} on input of {} with output of {}.",
dtype_to_string(in.dtype()), op,
dtype_to_string(out.dtype()))); dtype_to_string(in.dtype()),
} dtype_to_string(out.dtype())));
}); }
}); });
}); });
} }
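
The new unary_g kernel maps a flat output index back to a location in the possibly non-contiguous input via elem_to_loc_4d, using the collapsed shape and strides passed as __grid_constant__ parameters. The computation is the usual mixed-radix decomposition; a generic illustration (not the MLX helper, which is specialized for a bounded number of dimensions) is:

#include <cstdint>

// Illustration only: peel off the coordinate along each axis, innermost
// first, and accumulate coordinate * stride to obtain the input offset.
__host__ __device__ inline int64_t elem_to_loc_sketch(
    int64_t index,
    const int32_t* shape,
    const int64_t* strides,
    int ndim) {
  int64_t loc = 0;
  for (int d = ndim - 1; d >= 0; --d) {
    loc += (index % shape[d]) * strides[d];
    index /= shape[d];
  }
  return loc;
}

The contiguous path (unary_v) skips this lookup entirely, which is why it only takes a size argument.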
View File
@@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) {
} }
} }
void check_cuda_error(const char* name, CUresult err) {
if (err != CUDA_SUCCESS) {
const char* err_str = "Unknown error";
cuGetErrorString(err, &err_str);
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) { const char* dtype_to_cuda_type(const Dtype& dtype) {
switch (dtype) { switch (dtype) {
case bool_: case bool_:
View File
@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
namespace mlx::core { namespace mlx::core {
@@ -33,6 +34,7 @@ class CudaStream {
// Throw exception if the cuda API does not succeed. // Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed. // The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
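
With the CUresult overload added above, the same CHECK_CUDA_ERROR macro now covers both runtime-API calls (cudaError_t) and driver-API calls (CUresult); the latter presumably appear in the encoder's graph and node code paths. A hypothetical usage, assuming the header above is included, purely to show the two overloads side by side:

#include <cuda.h>
#include <cuda_runtime.h>

void check_both_apis(cudaStream_t stream) {
  CUcontext ctx;
  CHECK_CUDA_ERROR(cuCtxGetCurrent(&ctx));          // driver API: returns CUresult
  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));  // runtime API: returns cudaError_t
}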
View File
@@ -31,6 +31,7 @@ inline void threadgroup_sum(
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
x[i] = simd_sum(x[i]); x[i] = simd_sum(x[i]);
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
xs[N * simd_group_id + i] = x[i]; xs[N * simd_group_id + i] = x[i];
View File
@@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
perm = expand_dims(perm, -1, s); perm = expand_dims(perm, -1, s);
take_axis -= 1; take_axis -= 1;
} }
auto pb = take_along_axis(b, perm, take_axis); auto pb = take_along_axis(b, perm, take_axis, s);
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
return solve_triangular(luf[2], y, /* upper = */ true, s); return solve_triangular(luf[2], y, /* upper = */ true, s);
} }
View File
@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26 #define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 2 #define MLX_VERSION_PATCH 3
#define MLX_VERSION_NUMERIC \ #define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
View File
@@ -53,11 +53,7 @@ class CMakeBuild(build_ext):
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators. # across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand build_args += [f"-j{os.cpu_count()}"]
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]
build_temp = Path(self.build_temp) / ext.name build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists(): if not build_temp.exists():

View File
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_ * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_ * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
Note: The softmax operation is performed in ``float32`` regardless of .. note::
the input precision.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` * The softmax operation is performed in ``float32`` regardless of
and ``v`` inputs should not be pre-tiled to match ``q``. the input precision.
* For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
In the following the dimensions are given by: In the following the dimensions are given by:
@@ -195,13 +196,30 @@ void init_fast(nb::module_& parent_module) {
k (array): Keys with shape ``[B, N_kv, T_kv, D]``. k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (Union[None, str, array], optional): A causal, boolean or additive mask (Union[None, str, array], optional): The mask to apply to the
mask to apply to the query-key scores. The mask can have at most 4 query-key scores. The mask can be an array or a string indicating
dimensions and must be broadcast-compatible with the shape the mask type. The only supported string type is ``"causal"``. If
``[B, N, T_q, T_kv]``. If an additive mask is given its type must the mask is an array it can be a boolean or additive mask. The mask
promote to the promoted type of ``q``, ``k``, and ``v``. can have at most 4 dimensions and must be broadcast-compatible with
the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
type must promote to the promoted type of ``q``, ``k``, and ``v``.
Returns: Returns:
array: The output array. array: The output array.
Example:
.. code-block:: python
B = 2
N_q = N_kv = 32
T_q = T_kv = 1000
D = 128
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
scale = D ** -0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
)pbdoc"); )pbdoc");
m.def( m.def(

View File
scale = mx.array(2.0) scale = mx.array(2.0)
y = mx.load(save_file) y = mx.load(save_file)
mx.eval(y) mx.eval(y)
mx.synchronize()
load_only = mx.get_peak_memory() load_only = mx.get_peak_memory()
y = mx.load(save_file) * scale y = mx.load(save_file) * scale
mx.eval(y) mx.eval(y)
mx.synchronize()
load_with_binary = mx.get_peak_memory() load_with_binary = mx.get_peak_memory()
self.assertEqual(load_only, load_with_binary) self.assertEqual(load_only, load_with_binary)
View File
@@ -97,11 +97,7 @@ class CMakeBuild(build_ext):
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators. # across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand build_args += [f"-j{os.cpu_count()}"]
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]
build_temp = Path(self.build_temp) / ext.name build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists(): if not build_temp.exists():