Compare commits

...

6 Commits

Author SHA1 Message Date
Gaétan Lepage
0002e0083d
Merge d2e0b0465c into d7e680ffe4 2025-06-11 21:34:56 -04:00
Cheng
d7e680ffe4
CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00
Cheng
c371baf53a
CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c
CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a
CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
Gaetan Lepage
d2e0b0465c Feat: add USE_SYSTEM_FMT CMake option 2025-05-23 15:19:48 +02:00
16 changed files with 2112 additions and 12 deletions

View File

@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@ -234,12 +235,16 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@ -6,6 +6,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
@ -19,9 +20,16 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp

View File

@ -0,0 +1,189 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct IndexValPair {
uint32_t index;
T val;
};
template <typename T>
struct ArgMin {
constexpr __device__ T init() {
return Limits<T>::max();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T>
struct ArgMax {
constexpr __device__ T init() {
return Limits<T>::min();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
const T* in,
uint32_t* out,
size_t size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides in_strides,
const __grid_constant__ Strides out_strides,
int32_t ndim,
int64_t axis_stride,
int32_t axis_size) {
auto block = cg::this_thread_block();
int64_t index = cg::this_grid().block_rank();
if (index >= size) {
return;
}
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().z;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
best = BlockReduceT(temp).Reduce(best, op);
if (block.thread_rank() == 0) {
out[out_idx] = best.index;
}
}
} // namespace cu
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
Strides in_strides = remove_index(in.strides(), axis_);
Strides out_strides = out.ndim() == in.ndim()
? remove_index(out.strides(), axis_)
: out.strides();
int64_t axis_stride = in.strides()[axis_];
int32_t axis_size = in.shape()[axis_];
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
using InType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{1, 1, BLOCK_DIM};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,
BLOCK_DIM,
N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = &cu::arg_reduce_general<
InType,
cu::ArgMin<InType>,
BLOCK_DIM,
N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@ -47,6 +47,31 @@ namespace mlx::core {
__VA_ARGS__; \
}
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
{ \
uint32_t _num_threads = NUM_THREADS; \
if (_num_threads <= WARP_SIZE) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 2) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 4) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 8) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 16) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
__VA_ARGS__; \
} else { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
__VA_ARGS__; \
} \
}
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {

View File

@ -9,6 +9,8 @@
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>
@ -19,6 +21,10 @@ namespace mlx::core::cu {
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32
// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8
@ -26,6 +32,94 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename = void>
struct Limits {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T min() {
return cuda::std::numeric_limits<T>::min();
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
return cuda::std::numeric_limits<T>::min();
}
};
template <typename T>
struct Limits<
T,
cuda::std::enable_if_t<
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
return -cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
return cuda::std::numeric_limits<T>::lowest();
}
};
// CUDA 11 does not have host side arithmatic operators for half types.
template <typename T>
struct Limits<
T,
cuda::std::enable_if_t<
cuda::std::is_same_v<T, __half> ||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return -cuda::std::numeric_limits<T>::infinity();
#else
return -cuda::std::numeric_limits<float>::infinity();
#endif
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return cuda::std::numeric_limits<T>::lowest();
#else
return cuda::std::numeric_limits<float>::lowest();
#endif
}
};
template <>
struct Limits<bool> {
static constexpr __host__ __device__ bool max() {
return true;
}
static constexpr __host__ __device__ bool min() {
return false;
}
};
template <>
struct Limits<cuComplex> {
static constexpr __host__ __device__ cuComplex max() {
return {Limits<float>::max(), Limits<float>::max()};
}
static constexpr __host__ __device__ cuComplex min() {
return {Limits<float>::min(), Limits<float>::min()};
}
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
return cuda::std::make_tuple(a_loc, b_loc);
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
template <int DIM, bool General = true, typename OffsetT = size_t>
struct LoopedElemToLoc {
int dim;
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
OffsetT offset{0};
int index{0};
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
__device__ void next(const int* shape, const int64_t* strides) {
if (dim == 0) {
return;
}
index++;
offset += OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
}
}
__device__ void next(int n, const int* shape, const int64_t* strides) {
if (dim == 0) {
return;
}
index += n;
offset += n * OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
index = 0;
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
}
}
}
__device__ OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, true, OffsetT> {
int dim;
OffsetT offset{0};
int index{0};
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
__device__ void next(const int* shape, const int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
__device__ void next(int n, const int* shape, const int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}
__device__ OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, false, OffsetT> {
OffsetT offset{0};
__device__ LoopedElemToLoc(int) {}
__device__ void next(const int*, const int64_t* strides) {
offset += OffsetT(strides[0]);
}
__device__ void next(int n, const int*, const int64_t* strides) {
offset += n * OffsetT(strides[0]);
}
__device__ OffsetT location() {
return offset;
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,390 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
return {a.x + b.x, a.y + b.y, a.z + b.z};
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
return Reduce(input, cg::plus<T>{}, T{});
}
};
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm(
const T* x,
const T* w,
const T* b,
T* out,
float eps,
int32_t axis_size,
int64_t w_stride,
int64_t b_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
x += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceT{block, temp}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm_vjp(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF3::TempStorage f3;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
gx += grid.block_rank() * axis_size;
gw += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
float meanwg = factors.x / axis_size;
float meanwgxc = factors.y / axis_size;
float normalizer2 = 1 / (factors.z / axis_size + eps);
float normalizer = sqrt(normalizer2);
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
float gi = gn[i];
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
if constexpr (HAS_W) {
wn[i] = gi * xi;
}
}
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
} // namespace cu
namespace fast {
bool LayerNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
// TODO: There are duplicate code with backend/metal/normalization.cpp
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array& w = inputs[1];
const array& b = inputs[2];
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride,
b_stride);
});
});
});
}
void LayerNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& b = inputs[2];
auto [g, g_copied] = check_input(inputs[3]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
// Check whether we had a weight.
bool has_w = w.ndim() != 0;
// Allocate space for the outputs.
bool g_in_gx = false;
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
}
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
if (has_w) {
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
}
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});
if (has_w) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
}
}
} // namespace fast
} // namespace mlx::core

View File

@ -0,0 +1,159 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
inline __device__ T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return __expf(x);
}
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void logsumexp(const T* in, T* out, int axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
cg::greater<AccT> max_op;
cg::plus<AccT> plus_op;
// Thread reduce.
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
// First warp reduce.
prevmax = maxval;
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
normalizer = cg::reduce(warp, normalizer, plus_op);
__shared__ AccT local_max[WARP_SIZE];
__shared__ AccT local_normalizer[WARP_SIZE];
// Write to shared memory and do second warp reduce.
prevmax = maxval;
if (warp.thread_rank() == 0) {
local_max[warp.meta_group_rank()] = maxval;
}
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: Limits<AccT>::finite_min();
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {
local_normalizer[warp.meta_group_rank()] = normalizer;
}
block.sync();
normalizer = warp.thread_rank() < warp.meta_group_size()
? local_normalizer[warp.thread_rank()]
: AccT{};
normalizer = cg::reduce(warp, normalizer, plus_op);
// Write output.
if (block.thread_rank() == 0) {
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
}
}
} // namespace cu
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("LogSumExp::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}
} // namespace mlx::core

View File

@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
}
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)
NO_GPU(Convolution)
@ -86,18 +85,15 @@ NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Reduce)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
@ -105,8 +101,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)

View File

@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <cassert>
namespace mlx::core {
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Reduce::eval_gpu");
assert(inputs.size() == 1);
array in = inputs[0];
// Make sure no identity reductions trickle down here.
assert(!axes_.empty());
assert(out.size() != in.size());
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
// Fill out with init value.
if (in.size() == 0) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
thrust::fill_n(
cu::thrust_policy(stream),
thrust::device_pointer_cast(out.data<OutType>()),
out.data_size(),
cu::ReduceInit<OP, InType>::value());
});
});
});
return;
}
// Reduce.
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
if ((plan.type == ContiguousAllReduce) ||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
if (plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
throw std::runtime_error("No plan reached in reduce.");
}
} // namespace mlx::core

View File

@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct ColReduceArgs {
// The size of the contiguous column reduction.
size_t reduction_size;
int64_t reduction_stride;
// Input shape and strides excluding the reduction axes.
Shape shape;
Strides strides;
int ndim;
// Input shape and strides of the reduction axes (including last dimension).
Shape reduce_shape;
Strides reduce_strides;
int reduce_ndim;
// The number of column we are reducing. Namely prod(reduce_shape).
size_t non_col_reductions;
ColReduceArgs(
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
assert(!plan.shape.empty());
reduction_size = plan.shape.back();
reduction_stride = plan.strides.back();
int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back();
shape_vec.pop_back();
strides_vec.pop_back();
}
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
reduce_shape = const_param(plan.shape);
reduce_strides = const_param(plan.strides);
reduce_ndim = plan.shape.size();
non_col_reductions = 1;
for (int i = 0; i < reduce_ndim - 1; i++) {
non_col_reductions *= reduce_shape[i];
}
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void col_reduce_looped(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int n_warps = BN / N_READS;
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location() + in_offset),
vals,
args.reduction_stride - in_offset,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
// Do warp reduce for each output.
constexpr int n_outputs = BN / n_warps;
static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN];
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
for (int i = 0; i < N_READS; i++) {
shared_vals[col + i] = totals[i];
}
block.sync();
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
}
// Write result.
if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + out_idx * args.reduction_stride + out_offset,
totals,
args.reduction_stride - out_offset);
}
}
} // namespace cu
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args) {
auto out_shape = out.shape();
auto out_strides = out.strides();
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
out_shape.pop_back();
out_strides.pop_back();
}
return get_2d_grid_dims(out_shape, out_strides);
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
cu::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
// Dispatch dynamic ndim to constexpr.
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
if (ndim == 1) { \
constexpr uint32_t NDIM = 1; \
__VA_ARGS__; \
} else if (ndim == 2) { \
constexpr uint32_t NDIM = 2; \
__VA_ARGS__; \
} else { \
constexpr uint32_t NDIM = 5; \
__VA_ARGS__; \
}
// Dispatch reduce ops to constexpr.
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
if (REDUCE == Reduce::ReduceType::And) { \
using OP = cu::And; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Or) { \
using OP = cu::Or; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Sum) { \
using OP = cu::Sum; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Prod) { \
using OP = cu::Prod; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Max) { \
using OP = cu::Max; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Min) { \
using OP = cu::Min; \
__VA_ARGS__; \
} else { \
throw std::invalid_argument("Unknown reduce type."); \
}
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
} // namespace mlx::core

View File

@ -0,0 +1,144 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/utils.cuh"
namespace mlx::core::cu {
// Reduce ops.
struct And {
__device__ bool operator()(bool a, bool b) {
return a && b;
}
};
struct Or {
__device__ bool operator()(bool a, bool b) {
return a || b;
}
};
struct Sum {
template <typename T>
__device__ T operator()(T a, T b) {
return a + b;
}
};
struct Prod {
template <typename T>
__device__ T operator()(T a, T b) {
return a * b;
}
};
struct Min {
template <typename T>
__device__ T operator()(T a, T b) {
return a < b ? a : b;
}
};
struct Max {
template <typename T>
__device__ T operator()(T a, T b) {
return a > b ? a : b;
}
};
// Traits to get the result type of reduce op.
template <typename Op, typename T>
struct ReduceResult;
template <typename T>
struct ReduceResult<And, T> {
using type = bool;
};
template <typename T>
struct ReduceResult<Or, T> {
using type = bool;
};
template <typename T>
struct ReduceResult<Sum, T> {
using type = cuda::std::conditional_t<
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
int32_t,
T>;
};
template <typename T>
struct ReduceResult<Prod, T> {
using type = cuda::std::conditional_t<
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
int32_t,
T>;
};
template <typename T>
struct ReduceResult<Min, T> {
using type = T;
};
template <typename T>
struct ReduceResult<Max, T> {
using type = T;
};
// Traits to get the init value of reduce op.
template <typename Op, typename T>
struct ReduceInit;
template <typename T>
struct ReduceInit<And, T> {
static constexpr __host__ __device__ bool value() {
return true;
}
};
template <typename T>
struct ReduceInit<Or, T> {
static constexpr __host__ __device__ bool value() {
return false;
}
};
template <typename T>
struct ReduceInit<Sum, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0};
} else {
return typename ReduceResult<Sum, T>::type{0};
}
}
};
template <typename T>
struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 1};
} else {
return typename ReduceResult<Prod, T>::type{1};
}
}
};
template <typename T>
struct ReduceInit<Min, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::max();
}
};
template <typename T>
struct ReduceInit<Max, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::min();
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,250 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct RowReduceArgs {
// The size of the row being reduced, i.e. the size of last dimension.
int row_size;
// Input shape and strides excluding the reduction axes.
Shape shape;
Strides strides;
int ndim;
// Input shape and strides of the reduction axes excluding last dimension.
Shape reduce_shape;
Strides reduce_strides;
int reduce_ndim;
// The number of rows we are reducing. Namely prod(reduce_shape).
size_t non_row_reductions;
RowReduceArgs(
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
assert(!plan.shape.empty());
row_size = plan.shape.back();
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
reduce_shape = const_param(plan.shape);
reduce_strides = const_param(plan.strides);
reduce_ndim = plan.shape.size() - 1;
non_row_reductions = 1;
for (int i = 0; i < reduce_ndim; i++) {
non_row_reductions *= reduce_shape[i];
}
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / WARP_SIZE;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
n += WARP_SIZE) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
}
total_val = cg::reduce(warp, total_val, op);
if (warp.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM_X,
int N_READS = 4>
__global__ void row_reduce_looped(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM_X + block.thread_index().x,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
total_val = BlockReduceT(temp).Reduce(total_val, op);
if (block.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
} // namespace cu
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
cu::RowReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr size_t N_READS = 4;
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims, num_blocks;
auto kernel =
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
if (args.row_size <= 64) {
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
(args.non_row_reductions <= 8)) {
block_dims.x = std::min(out_dims.x, 1024u);
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
num_blocks.y = out_dims.y;
} else {
block_dims.x = WARP_SIZE;
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
kernel =
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
}
} else {
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
block_dims.x = BLOCK_DIM_X;
kernel = cu::row_reduce_looped<
InType,
OutType,
OP,
NDIM,
BLOCK_DIM_X,
N_READS>;
});
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), out.size(), args);
});
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,84 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>
namespace mlx::core {
template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}
template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}
struct MultiplyOp {
int factor;
__device__ int operator()(int i) {
return i * factor;
}
};
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
auto in_iter = cu::make_cast_iterator<OutType>(
thrust::device_pointer_cast(in.data<InType>()));
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
auto init = cu::ReduceInit<OP, InType>::value();
if (plan.type == ContiguousAllReduce) {
cub_all_reduce(
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
} else if (plan.type == ContiguousReduce) {
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
cub_segmented_reduce(
encoder,
in_iter,
out_ptr,
out.size(),
offsets,
offsets + 1,
OP(),
init,
stream);
} else {
throw std::runtime_error("Unsupported plan in segmented_reduce.");
}
});
});
});
}
} // namespace mlx::core

160
mlx/backend/cuda/softmax.cu Normal file
View File

@ -0,0 +1,160 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
inline __device__ T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return __expf(x);
}
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void softmax(const T* in, T* out, int axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
cg::greater<AccT> max_op;
cg::plus<AccT> plus_op;
// Thread reduce.
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::finite_min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
// First warp reduce.
prevmax = maxval;
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
normalizer = cg::reduce(warp, normalizer, plus_op);
__shared__ AccT local_max[WARP_SIZE];
__shared__ AccT local_normalizer[WARP_SIZE];
// Write to shared memory and do second warp reduce.
prevmax = maxval;
if (warp.thread_rank() == 0) {
local_max[warp.meta_group_rank()] = maxval;
}
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: Limits<AccT>::finite_min();
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {
local_normalizer[warp.meta_group_rank()] = normalizer;
}
block.sync();
normalizer = warp.thread_rank() < warp.meta_group_size()
? local_normalizer[warp.thread_rank()]
: AccT{};
normalizer = cg::reduce(warp, normalizer, plus_op);
normalizer = 1 / normalizer;
// Write output.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
T vals[N_READS];
cub::LoadDirectBlocked(index, in, vals, axis_size);
for (int i = 0; i < N_READS; i++) {
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
}
cub::StoreDirectBlocked(index, out, vals, axis_size);
}
}
} // namespace cu
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Softmax::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array in = set_output(inputs[0]);
bool precise = in.dtype() != float32 && precise_;
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
if (precise) {
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
}
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}
} // namespace mlx::core