mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-02 08:46:42 +08:00
Compare commits
6 Commits
543fad7536
...
0002e0083d
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0002e0083d | ||
![]() |
d7e680ffe4 | ||
![]() |
c371baf53a | ||
![]() |
ccf78f566c | ||
![]() |
c9fa68664a | ||
![]() |
d2e0b0465c |
@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
|||||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
message(
|
message(
|
||||||
@ -234,12 +235,16 @@ target_include_directories(
|
|||||||
# Do not add mlx_EXPORTS define for shared library.
|
# Do not add mlx_EXPORTS define for shared library.
|
||||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||||
|
|
||||||
FetchContent_Declare(
|
if(USE_SYSTEM_FMT)
|
||||||
fmt
|
find_package(fmt REQUIRED)
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
else()
|
||||||
GIT_TAG 10.2.1
|
FetchContent_Declare(
|
||||||
EXCLUDE_FROM_ALL)
|
fmt
|
||||||
FetchContent_MakeAvailable(fmt)
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
|
GIT_TAG 10.2.1
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
endif()
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||||
@ -19,9 +20,16 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
189
mlx/backend/cuda/arg_reduce.cu
Normal file
189
mlx/backend/cuda/arg_reduce.cu
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_reduce.cuh>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct IndexValPair {
|
||||||
|
uint32_t index;
|
||||||
|
T val;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ArgMin {
|
||||||
|
constexpr __device__ T init() {
|
||||||
|
return Limits<T>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ IndexValPair<T> operator()(
|
||||||
|
const IndexValPair<T>& best,
|
||||||
|
const IndexValPair<T>& current) {
|
||||||
|
if (best.val > current.val ||
|
||||||
|
(best.val == current.val && best.index > current.index)) {
|
||||||
|
return current;
|
||||||
|
} else {
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
__device__ IndexValPair<T>
|
||||||
|
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
if (vals[i] < best.val) {
|
||||||
|
best.val = vals[i];
|
||||||
|
best.index = offset + i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ArgMax {
|
||||||
|
constexpr __device__ T init() {
|
||||||
|
return Limits<T>::min();
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ IndexValPair<T> operator()(
|
||||||
|
const IndexValPair<T>& best,
|
||||||
|
const IndexValPair<T>& current) {
|
||||||
|
if (best.val < current.val ||
|
||||||
|
(best.val == current.val && best.index > current.index)) {
|
||||||
|
return current;
|
||||||
|
} else {
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
__device__ IndexValPair<T>
|
||||||
|
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
if (vals[i] > best.val) {
|
||||||
|
best.val = vals[i];
|
||||||
|
best.index = offset + i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void arg_reduce_general(
|
||||||
|
const T* in,
|
||||||
|
uint32_t* out,
|
||||||
|
size_t size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides in_strides,
|
||||||
|
const __grid_constant__ Strides out_strides,
|
||||||
|
int32_t ndim,
|
||||||
|
int64_t axis_stride,
|
||||||
|
int32_t axis_size) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
int64_t index = cg::this_grid().block_rank();
|
||||||
|
if (index >= size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||||
|
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
T init = op.init();
|
||||||
|
IndexValPair<T> best{0, init};
|
||||||
|
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
T vals[N_READS];
|
||||||
|
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||||
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
|
||||||
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
best = BlockReduceT(temp).Reduce(best, op);
|
||||||
|
|
||||||
|
if (block.thread_rank() == 0) {
|
||||||
|
out[out_idx] = best.index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
// Prepare the shapes, strides and axis arguments.
|
||||||
|
Shape shape = remove_index(in.shape(), axis_);
|
||||||
|
Strides in_strides = remove_index(in.strides(), axis_);
|
||||||
|
Strides out_strides = out.ndim() == in.ndim()
|
||||||
|
? remove_index(out.strides(), axis_)
|
||||||
|
: out.strides();
|
||||||
|
int64_t axis_stride = in.strides()[axis_];
|
||||||
|
int32_t axis_size = in.shape()[axis_];
|
||||||
|
int32_t ndim = shape.size();
|
||||||
|
|
||||||
|
// ArgReduce.
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr uint32_t N_READS = 4;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
dim3 block_dims{1, 1, BLOCK_DIM};
|
||||||
|
auto kernel = &cu::arg_reduce_general<
|
||||||
|
InType,
|
||||||
|
cu::ArgMax<InType>,
|
||||||
|
BLOCK_DIM,
|
||||||
|
N_READS>;
|
||||||
|
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||||
|
kernel = &cu::arg_reduce_general<
|
||||||
|
InType,
|
||||||
|
cu::ArgMin<InType>,
|
||||||
|
BLOCK_DIM,
|
||||||
|
N_READS>;
|
||||||
|
}
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
in.data<InType>(),
|
||||||
|
out.data<uint32_t>(),
|
||||||
|
out.size(),
|
||||||
|
const_param(shape),
|
||||||
|
const_param(in_strides),
|
||||||
|
const_param(out_strides),
|
||||||
|
ndim,
|
||||||
|
axis_stride,
|
||||||
|
axis_size);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <thrust/iterator/iterator_adaptor.h>
|
||||||
|
#include <thrust/iterator/iterator_facade.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// RandomAccessIterator for strided access to array entries.
|
||||||
|
template <typename Iterator, typename Stride = int64_t>
|
||||||
|
class strided_iterator
|
||||||
|
: public thrust::
|
||||||
|
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||||
|
public:
|
||||||
|
using super_t =
|
||||||
|
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||||
|
|
||||||
|
using reference = typename super_t::reference;
|
||||||
|
using difference_type = typename super_t::difference_type;
|
||||||
|
|
||||||
|
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||||
|
: super_t(it), stride_(stride) {}
|
||||||
|
|
||||||
|
__host__ __device__ Stride stride() const {
|
||||||
|
return stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class thrust::iterator_core_access;
|
||||||
|
|
||||||
|
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||||
|
return this->base() == other.base();
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void advance(difference_type n) {
|
||||||
|
this->base_reference() += n * stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void increment() {
|
||||||
|
this->base_reference() += stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void decrement() {
|
||||||
|
this->base_reference() -= stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ difference_type
|
||||||
|
distance_to(const strided_iterator& other) const {
|
||||||
|
const difference_type dist = other.base() - this->base();
|
||||||
|
_CCCL_ASSERT(
|
||||||
|
dist % stride() == 0,
|
||||||
|
"Underlying iterator difference must be divisible by the stride");
|
||||||
|
return dist / stride();
|
||||||
|
}
|
||||||
|
|
||||||
|
Stride stride_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -47,6 +47,31 @@ namespace mlx::core {
|
|||||||
__VA_ARGS__; \
|
__VA_ARGS__; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||||
|
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||||
|
{ \
|
||||||
|
uint32_t _num_threads = NUM_THREADS; \
|
||||||
|
if (_num_threads <= WARP_SIZE) { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else { \
|
||||||
|
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
// Maps CPU types to CUDA types.
|
// Maps CPU types to CUDA types.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct CTypeToCudaType {
|
struct CTypeToCudaType {
|
||||||
|
@ -9,6 +9,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
#include <cuda/std/limits>
|
#include <cuda/std/limits>
|
||||||
#include <cuda/std/tuple>
|
#include <cuda/std/tuple>
|
||||||
@ -19,6 +21,10 @@ namespace mlx::core::cu {
|
|||||||
// CUDA kernel utils
|
// CUDA kernel utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||||
|
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
// To pass shape/strides to kernels via constant memory, their size must be
|
// To pass shape/strides to kernels via constant memory, their size must be
|
||||||
// known at compile time.
|
// known at compile time.
|
||||||
#define MAX_NDIM 8
|
#define MAX_NDIM 8
|
||||||
@ -26,6 +32,94 @@ namespace mlx::core::cu {
|
|||||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Type limits utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T, typename = void>
|
||||||
|
struct Limits {
|
||||||
|
static constexpr __host__ __device__ T max() {
|
||||||
|
return cuda::std::numeric_limits<T>::max();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T min() {
|
||||||
|
return cuda::std::numeric_limits<T>::min();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_max() {
|
||||||
|
return cuda::std::numeric_limits<T>::max();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_min() {
|
||||||
|
return cuda::std::numeric_limits<T>::min();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Limits<
|
||||||
|
T,
|
||||||
|
cuda::std::enable_if_t<
|
||||||
|
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||||
|
static constexpr __host__ __device__ T max() {
|
||||||
|
return cuda::std::numeric_limits<T>::infinity();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T min() {
|
||||||
|
return -cuda::std::numeric_limits<T>::infinity();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_max() {
|
||||||
|
return cuda::std::numeric_limits<T>::max();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_min() {
|
||||||
|
return cuda::std::numeric_limits<T>::lowest();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||||
|
template <typename T>
|
||||||
|
struct Limits<
|
||||||
|
T,
|
||||||
|
cuda::std::enable_if_t<
|
||||||
|
cuda::std::is_same_v<T, __half> ||
|
||||||
|
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||||
|
static constexpr __host__ __device__ T max() {
|
||||||
|
return cuda::std::numeric_limits<T>::infinity();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T min() {
|
||||||
|
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||||
|
return -cuda::std::numeric_limits<T>::infinity();
|
||||||
|
#else
|
||||||
|
return -cuda::std::numeric_limits<float>::infinity();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_max() {
|
||||||
|
return cuda::std::numeric_limits<T>::max();
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ T finite_min() {
|
||||||
|
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||||
|
return cuda::std::numeric_limits<T>::lowest();
|
||||||
|
#else
|
||||||
|
return cuda::std::numeric_limits<float>::lowest();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Limits<bool> {
|
||||||
|
static constexpr __host__ __device__ bool max() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ bool min() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Limits<cuComplex> {
|
||||||
|
static constexpr __host__ __device__ cuComplex max() {
|
||||||
|
return {Limits<float>::max(), Limits<float>::max()};
|
||||||
|
}
|
||||||
|
static constexpr __host__ __device__ cuComplex min() {
|
||||||
|
return {Limits<float>::min(), Limits<float>::min()};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Indexing utils
|
// Indexing utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Elem to loc in a loop utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||||
|
struct LoopedElemToLoc {
|
||||||
|
int dim;
|
||||||
|
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||||
|
OffsetT offset{0};
|
||||||
|
int index{0};
|
||||||
|
|
||||||
|
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||||
|
|
||||||
|
__device__ void next(const int* shape, const int64_t* strides) {
|
||||||
|
if (dim == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
offset += OffsetT(strides[dim - 1]);
|
||||||
|
if (index >= shape[dim - 1]) {
|
||||||
|
index = 0;
|
||||||
|
inner_looper.next(shape, strides);
|
||||||
|
offset = inner_looper.offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||||
|
if (dim == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
index += n;
|
||||||
|
offset += n * OffsetT(strides[dim - 1]);
|
||||||
|
|
||||||
|
if (index >= shape[dim - 1]) {
|
||||||
|
int extra = index - shape[dim - 1];
|
||||||
|
if (extra >= shape[dim - 1]) {
|
||||||
|
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||||
|
extra = extra % shape[dim - 1];
|
||||||
|
} else {
|
||||||
|
inner_looper.next(shape, strides);
|
||||||
|
}
|
||||||
|
index = 0;
|
||||||
|
offset = inner_looper.offset;
|
||||||
|
if (extra > 0) {
|
||||||
|
next(extra, shape, strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ OffsetT location() {
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OffsetT>
|
||||||
|
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||||
|
int dim;
|
||||||
|
OffsetT offset{0};
|
||||||
|
int index{0};
|
||||||
|
|
||||||
|
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||||
|
|
||||||
|
__device__ void next(const int* shape, const int64_t* strides) {
|
||||||
|
index++;
|
||||||
|
if (dim > 1) {
|
||||||
|
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||||
|
} else {
|
||||||
|
offset += OffsetT(strides[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||||
|
index += n;
|
||||||
|
if (dim > 1) {
|
||||||
|
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||||
|
} else {
|
||||||
|
offset = index * OffsetT(strides[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ OffsetT location() {
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OffsetT>
|
||||||
|
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||||
|
OffsetT offset{0};
|
||||||
|
|
||||||
|
__device__ LoopedElemToLoc(int) {}
|
||||||
|
|
||||||
|
__device__ void next(const int*, const int64_t* strides) {
|
||||||
|
offset += OffsetT(strides[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||||
|
offset += n * OffsetT(strides[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ OffsetT location() {
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
390
mlx/backend/cuda/layer_norm.cu
Normal file
390
mlx/backend/cuda/layer_norm.cu
Normal file
@ -0,0 +1,390 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||||
|
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||||
|
template <typename T, int BLOCK_DIM>
|
||||||
|
struct BlockBroadcastReduce {
|
||||||
|
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||||
|
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||||
|
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||||
|
|
||||||
|
cg::thread_block& block;
|
||||||
|
TempStorage& temp;
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
T x = cg::reduce(warp, input, op);
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
temp[warp.meta_group_rank()] = x;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||||
|
: init_value;
|
||||||
|
return cg::reduce(warp, x, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ T Sum(const T& input) {
|
||||||
|
return Reduce(input, cg::plus<T>{}, T{});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void layer_norm(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
const T* b,
|
||||||
|
T* out,
|
||||||
|
float eps,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t w_stride,
|
||||||
|
int64_t b_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||||
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
x += grid.block_rank() * axis_size;
|
||||||
|
out += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
// Sum.
|
||||||
|
float sum = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS] = {};
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||||
|
}
|
||||||
|
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||||
|
|
||||||
|
// Mean.
|
||||||
|
float mean = sum / axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float normalizer = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float t = static_cast<float>(xn[i]) - mean;
|
||||||
|
normalizer += t * t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||||
|
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS];
|
||||||
|
T bn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
|
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||||
|
}
|
||||||
|
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void layer_norm_vjp(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
const T* g,
|
||||||
|
T* gx,
|
||||||
|
T* gw,
|
||||||
|
float eps,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t w_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||||
|
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||||
|
__shared__ union {
|
||||||
|
typename BlockReduceF::TempStorage f;
|
||||||
|
typename BlockReduceF3::TempStorage f3;
|
||||||
|
} temp;
|
||||||
|
|
||||||
|
x += grid.block_rank() * axis_size;
|
||||||
|
g += grid.block_rank() * axis_size;
|
||||||
|
gx += grid.block_rank() * axis_size;
|
||||||
|
gw += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
// Sum.
|
||||||
|
float sum = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS] = {};
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||||
|
}
|
||||||
|
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||||
|
|
||||||
|
// Mean.
|
||||||
|
float mean = sum / axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float3 factors = {};
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS] = {};
|
||||||
|
T gn[N_READS] = {};
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||||
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float t = static_cast<float>(xn[i]) - mean;
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
float wg = wi * gi;
|
||||||
|
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||||
|
float meanwg = factors.x / axis_size;
|
||||||
|
float meanwgxc = factors.y / axis_size;
|
||||||
|
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||||
|
float normalizer = sqrt(normalizer2);
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS];
|
||||||
|
T gn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
wn[i] = gi * xi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace fast {
|
||||||
|
|
||||||
|
bool LayerNorm::use_fallback(Stream s) {
|
||||||
|
return s.device == Device::cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||||
|
void LayerNorm::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("LayerNorm::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& out = outputs[0];
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous.
|
||||||
|
auto set_output = [&s, &out](const array& x) {
|
||||||
|
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||||
|
if (no_copy && x.ndim() > 1) {
|
||||||
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
|
no_copy &= (s == 0 || s == x.shape().back());
|
||||||
|
}
|
||||||
|
if (no_copy) {
|
||||||
|
if (x.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(x);
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(x.data_size() * x.itemsize()),
|
||||||
|
x.data_size(),
|
||||||
|
x.strides(),
|
||||||
|
x.flags());
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
out.copy_shared_buffer(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
array o = set_output(inputs[0]);
|
||||||
|
const array& x = o.data_shared_ptr() ? o : out;
|
||||||
|
const array& w = inputs[1];
|
||||||
|
const array& b = inputs[2];
|
||||||
|
|
||||||
|
int32_t axis_size = x.shape().back();
|
||||||
|
int32_t n_rows = x.data_size() / axis_size;
|
||||||
|
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
|
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(x);
|
||||||
|
encoder.set_input_array(w);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr uint32_t N_READS = 4;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
x.data<DataType>(),
|
||||||
|
w.data<DataType>(),
|
||||||
|
b.data<DataType>(),
|
||||||
|
out.data<DataType>(),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
w_stride,
|
||||||
|
b_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void LayerNormVJP::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// Ensure row contiguity. We could relax this step by checking that the array
|
||||||
|
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||||
|
// same as the cotangent strides but for now this is simpler.
|
||||||
|
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||||
|
if (x.flags().row_contiguous) {
|
||||||
|
return {x, false};
|
||||||
|
}
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
return {x_copy, true};
|
||||||
|
};
|
||||||
|
bool donate_x = inputs[0].is_donatable();
|
||||||
|
bool donate_g = inputs[3].is_donatable();
|
||||||
|
auto [x, copied] = check_input(inputs[0]);
|
||||||
|
donate_x |= copied;
|
||||||
|
const array& w = inputs[1];
|
||||||
|
const array& b = inputs[2];
|
||||||
|
auto [g, g_copied] = check_input(inputs[3]);
|
||||||
|
donate_g |= g_copied;
|
||||||
|
array& gx = outputs[0];
|
||||||
|
array& gw = outputs[1];
|
||||||
|
array& gb = outputs[2];
|
||||||
|
|
||||||
|
// Check whether we had a weight.
|
||||||
|
bool has_w = w.ndim() != 0;
|
||||||
|
|
||||||
|
// Allocate space for the outputs.
|
||||||
|
bool g_in_gx = false;
|
||||||
|
if (donate_x) {
|
||||||
|
gx.copy_shared_buffer(x);
|
||||||
|
} else if (donate_g) {
|
||||||
|
gx.copy_shared_buffer(g);
|
||||||
|
g_in_gx = true;
|
||||||
|
} else {
|
||||||
|
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||||
|
}
|
||||||
|
if (g_copied && !g_in_gx) {
|
||||||
|
encoder.add_temporary(g);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t axis_size = x.shape().back();
|
||||||
|
int32_t n_rows = x.data_size() / axis_size;
|
||||||
|
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
|
|
||||||
|
// Allocate a temporary to store the gradients for w and allocate the output
|
||||||
|
// gradient accumulators.
|
||||||
|
array gw_temp =
|
||||||
|
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||||
|
if (has_w) {
|
||||||
|
if (!g_in_gx && donate_g) {
|
||||||
|
gw_temp.copy_shared_buffer(g);
|
||||||
|
} else {
|
||||||
|
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||||
|
encoder.add_temporary(gw_temp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||||
|
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||||
|
|
||||||
|
// Finish with the gradient for b in case we had a b.
|
||||||
|
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||||
|
ReductionPlan plan(
|
||||||
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
|
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(x);
|
||||||
|
encoder.set_input_array(w);
|
||||||
|
encoder.set_input_array(g);
|
||||||
|
encoder.set_output_array(gx);
|
||||||
|
encoder.set_output_array(gw_temp);
|
||||||
|
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
x.data<DataType>(),
|
||||||
|
w.data<DataType>(),
|
||||||
|
g.data<DataType>(),
|
||||||
|
gx.data<DataType>(),
|
||||||
|
gw_temp.data<DataType>(),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
w_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (has_w) {
|
||||||
|
ReductionPlan plan(
|
||||||
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
|
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fast
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
159
mlx/backend/cuda/logsumexp.cu
Normal file
159
mlx/backend/cuda/logsumexp.cu
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T softmax_exp(T x) {
|
||||||
|
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||||
|
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||||
|
return __expf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void logsumexp(const T* in, T* out, int axis_size) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
|
in += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
cg::greater<AccT> max_op;
|
||||||
|
cg::plus<AccT> plus_op;
|
||||||
|
|
||||||
|
// Thread reduce.
|
||||||
|
AccT prevmax;
|
||||||
|
AccT maxval = Limits<AccT>::finite_min();
|
||||||
|
AccT normalizer = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||||
|
AccT vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r * BLOCK_DIM + block.thread_rank(),
|
||||||
|
make_cast_iterator<AccT>(in),
|
||||||
|
vals,
|
||||||
|
axis_size,
|
||||||
|
Limits<AccT>::min());
|
||||||
|
prevmax = maxval;
|
||||||
|
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||||
|
// Online normalizer calculation for softmax:
|
||||||
|
// https://github.com/NVIDIA/online-softmax
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First warp reduce.
|
||||||
|
prevmax = maxval;
|
||||||
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||||
|
|
||||||
|
__shared__ AccT local_max[WARP_SIZE];
|
||||||
|
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||||
|
|
||||||
|
// Write to shared memory and do second warp reduce.
|
||||||
|
prevmax = maxval;
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
local_max[warp.meta_group_rank()] = maxval;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||||
|
? local_max[warp.thread_rank()]
|
||||||
|
: Limits<AccT>::finite_min();
|
||||||
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||||
|
? local_normalizer[warp.thread_rank()]
|
||||||
|
: AccT{};
|
||||||
|
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||||
|
|
||||||
|
// Write output.
|
||||||
|
if (block.thread_rank() == 0) {
|
||||||
|
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("LogSumExp::eval_gpu");
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous.
|
||||||
|
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||||
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
encoder.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto in = ensure_contiguous(inputs[0]);
|
||||||
|
if (in.flags().row_contiguous) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
} else {
|
||||||
|
auto n = in.shape(-1);
|
||||||
|
auto flags = in.flags();
|
||||||
|
auto strides = in.strides();
|
||||||
|
for (auto& s : strides) {
|
||||||
|
s /= n;
|
||||||
|
}
|
||||||
|
bool col_contig = strides[0] == 1;
|
||||||
|
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||||
|
col_contig &=
|
||||||
|
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||||
|
}
|
||||||
|
flags.col_contiguous = col_contig;
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(in.nbytes() / n),
|
||||||
|
in.data_size() / n,
|
||||||
|
std::move(strides),
|
||||||
|
flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
int axis_size = in.shape().back();
|
||||||
|
int n_rows = in.data_size() / axis_size;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
NO_GPU(ArgPartition)
|
NO_GPU(ArgPartition)
|
||||||
NO_GPU(ArgReduce)
|
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU_MULTI(Compiled)
|
NO_GPU_MULTI(Compiled)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
@ -86,18 +85,15 @@ NO_GPU(GatherMM)
|
|||||||
NO_GPU(GatherQMM)
|
NO_GPU(GatherQMM)
|
||||||
NO_GPU(Hadamard)
|
NO_GPU(Hadamard)
|
||||||
NO_GPU(Load)
|
NO_GPU(Load)
|
||||||
NO_GPU(LogSumExp)
|
|
||||||
NO_GPU_MULTI(LUF)
|
NO_GPU_MULTI(LUF)
|
||||||
NO_GPU(Partition)
|
NO_GPU(Partition)
|
||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(Reduce)
|
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(ScatterAxis)
|
NO_GPU(ScatterAxis)
|
||||||
NO_GPU(Select)
|
NO_GPU(Select)
|
||||||
NO_GPU(SliceUpdate)
|
NO_GPU(SliceUpdate)
|
||||||
NO_GPU(Softmax)
|
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
NO_GPU(Cholesky)
|
NO_GPU(Cholesky)
|
||||||
@ -105,8 +101,6 @@ NO_GPU_MULTI(Eig)
|
|||||||
NO_GPU_MULTI(Eigh)
|
NO_GPU_MULTI(Eigh)
|
||||||
|
|
||||||
namespace fast {
|
namespace fast {
|
||||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
|
||||||
NO_GPU_MULTI(LayerNormVJP)
|
|
||||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||||
NO_GPU_MULTI(RMSNormVJP)
|
NO_GPU_MULTI(RMSNormVJP)
|
||||||
NO_GPU_USE_FALLBACK(RoPE)
|
NO_GPU_USE_FALLBACK(RoPE)
|
||||||
|
82
mlx/backend/cuda/reduce.cu
Normal file
82
mlx/backend/cuda/reduce.cu
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <thrust/fill.h>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Reduce::eval_gpu");
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
array in = inputs[0];
|
||||||
|
|
||||||
|
// Make sure no identity reductions trickle down here.
|
||||||
|
assert(!axes_.empty());
|
||||||
|
assert(out.size() != in.size());
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
// Fill out with init value.
|
||||||
|
if (in.size() == 0) {
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
thrust::fill_n(
|
||||||
|
cu::thrust_policy(stream),
|
||||||
|
thrust::device_pointer_cast(out.data<OutType>()),
|
||||||
|
out.data_size(),
|
||||||
|
cu::ReduceInit<OP, InType>::value());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reduce.
|
||||||
|
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||||
|
|
||||||
|
// If it is a general reduce then copy the input to a contiguous array and
|
||||||
|
// recompute the plan.
|
||||||
|
if (plan.type == GeneralReduce) {
|
||||||
|
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
|
copy_gpu(in, in_copy, CopyType::General, s);
|
||||||
|
encoder.add_temporary(in_copy);
|
||||||
|
in = in_copy;
|
||||||
|
plan = get_reduction_plan(in, axes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((plan.type == ContiguousAllReduce) ||
|
||||||
|
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||||
|
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||||
|
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == ContiguousStridedReduce ||
|
||||||
|
plan.type == GeneralStridedReduce) {
|
||||||
|
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::runtime_error("No plan reached in reduce.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
struct ColReduceArgs {
|
||||||
|
// The size of the contiguous column reduction.
|
||||||
|
size_t reduction_size;
|
||||||
|
int64_t reduction_stride;
|
||||||
|
|
||||||
|
// Input shape and strides excluding the reduction axes.
|
||||||
|
Shape shape;
|
||||||
|
Strides strides;
|
||||||
|
int ndim;
|
||||||
|
|
||||||
|
// Input shape and strides of the reduction axes (including last dimension).
|
||||||
|
Shape reduce_shape;
|
||||||
|
Strides reduce_strides;
|
||||||
|
int reduce_ndim;
|
||||||
|
|
||||||
|
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||||
|
size_t non_col_reductions;
|
||||||
|
|
||||||
|
ColReduceArgs(
|
||||||
|
const array& in,
|
||||||
|
const ReductionPlan& plan,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(!plan.shape.empty());
|
||||||
|
reduction_size = plan.shape.back();
|
||||||
|
reduction_stride = plan.strides.back();
|
||||||
|
|
||||||
|
int64_t stride_back = 1;
|
||||||
|
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||||
|
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||||
|
stride_back *= shape_vec.back();
|
||||||
|
shape_vec.pop_back();
|
||||||
|
strides_vec.pop_back();
|
||||||
|
}
|
||||||
|
std::tie(shape_vec, strides_vec) =
|
||||||
|
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||||
|
shape = const_param(shape_vec);
|
||||||
|
strides = const_param(strides_vec);
|
||||||
|
ndim = shape_vec.size();
|
||||||
|
|
||||||
|
reduce_shape = const_param(plan.shape);
|
||||||
|
reduce_strides = const_param(plan.strides);
|
||||||
|
reduce_ndim = plan.shape.size();
|
||||||
|
|
||||||
|
non_col_reductions = 1;
|
||||||
|
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||||
|
non_col_reductions *= reduce_shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
|
__global__ void col_reduce_small(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
int column =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
if (column * N_READS >= args.reduction_stride) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
U totals[N_READS];
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read input to local.
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
loop.next(
|
||||||
|
block.thread_index().y,
|
||||||
|
args.reduce_shape.data(),
|
||||||
|
args.reduce_strides.data());
|
||||||
|
for (size_t r = block.thread_index().y;
|
||||||
|
r < args.non_col_reductions * args.reduction_size;
|
||||||
|
r += block.dim_threads().y) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
column,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.reduction_stride,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(vals[i], totals[i]);
|
||||||
|
}
|
||||||
|
loop.next(
|
||||||
|
block.dim_threads().y,
|
||||||
|
args.reduce_shape.data(),
|
||||||
|
args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do block reduce when each column has more than 1 element to reduce.
|
||||||
|
if (block.dim_threads().y > 1) {
|
||||||
|
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||||
|
size_t col =
|
||||||
|
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
shared_vals[col * N_READS + i] = totals[i];
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
if (block.thread_index().y == 0) {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||||
|
}
|
||||||
|
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||||
|
col = j * block.dim_threads().x + block.thread_index().x;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write result.
|
||||||
|
if (block.thread_index().y == 0) {
|
||||||
|
cub::StoreDirectBlocked(
|
||||||
|
column,
|
||||||
|
out + out_idx * args.reduction_stride,
|
||||||
|
totals,
|
||||||
|
args.reduction_stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
int NDIM,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int N_READS = 4>
|
||||||
|
__global__ void col_reduce_looped(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
|
constexpr int n_warps = BN / N_READS;
|
||||||
|
|
||||||
|
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
U totals[N_READS];
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read input to local.
|
||||||
|
int r = block.thread_rank() / n_warps;
|
||||||
|
int column = block.thread_rank() % n_warps;
|
||||||
|
int in_offset = grid.block_index().x * BN;
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
column,
|
||||||
|
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||||
|
vals,
|
||||||
|
args.reduction_stride - in_offset,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(vals[i], totals[i]);
|
||||||
|
}
|
||||||
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do warp reduce for each output.
|
||||||
|
constexpr int n_outputs = BN / n_warps;
|
||||||
|
static_assert(BM == 32 && n_outputs == N_READS);
|
||||||
|
__shared__ U shared_vals[BM * BN];
|
||||||
|
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
shared_vals[col + i] = totals[i];
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||||
|
for (int i = 0; i < n_outputs; i++) {
|
||||||
|
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write result.
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
size_t out_offset = grid.block_index().x * BN;
|
||||||
|
cub::StoreDirectBlocked(
|
||||||
|
warp.meta_group_rank(),
|
||||||
|
out + out_idx * args.reduction_stride + out_offset,
|
||||||
|
totals,
|
||||||
|
args.reduction_stride - out_offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
inline auto output_grid_for_col_reduce(
|
||||||
|
const array& out,
|
||||||
|
const cu::ColReduceArgs& args) {
|
||||||
|
auto out_shape = out.shape();
|
||||||
|
auto out_strides = out.strides();
|
||||||
|
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||||
|
out_shape.pop_back();
|
||||||
|
out_strides.pop_back();
|
||||||
|
}
|
||||||
|
return get_2d_grid_dims(out_shape, out_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void col_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
cu::ColReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
dim3 block_dims;
|
||||||
|
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||||
|
num_blocks.z = num_blocks.y;
|
||||||
|
num_blocks.y = num_blocks.x;
|
||||||
|
auto kernel =
|
||||||
|
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
size_t total = args.non_col_reductions * args.reduction_size;
|
||||||
|
if (total < 32) {
|
||||||
|
size_t stride_blocks =
|
||||||
|
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||||
|
block_dims.x = std::min(stride_blocks, 32ul);
|
||||||
|
block_dims.y = std::min(total, 8ul);
|
||||||
|
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||||
|
} else {
|
||||||
|
constexpr int BM = 32;
|
||||||
|
constexpr int BN = 32;
|
||||||
|
block_dims.x = BM * BN / N_READS;
|
||||||
|
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||||
|
kernel = cu::
|
||||||
|
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||||
|
}
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
in.data<InType>(), out.data<OutType>(), args);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Dispatch dynamic ndim to constexpr.
|
||||||
|
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||||
|
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||||
|
if (ndim == 1) { \
|
||||||
|
constexpr uint32_t NDIM = 1; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (ndim == 2) { \
|
||||||
|
constexpr uint32_t NDIM = 2; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else { \
|
||||||
|
constexpr uint32_t NDIM = 5; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch reduce ops to constexpr.
|
||||||
|
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||||
|
if (REDUCE == Reduce::ReduceType::And) { \
|
||||||
|
using OP = cu::And; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||||
|
using OP = cu::Or; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||||
|
using OP = cu::Sum; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||||
|
using OP = cu::Prod; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||||
|
using OP = cu::Max; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||||
|
using OP = cu::Min; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} else { \
|
||||||
|
throw std::invalid_argument("Unknown reduce type."); \
|
||||||
|
}
|
||||||
|
|
||||||
|
void segmented_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
|
void row_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
|
void col_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// Reduce ops.
|
||||||
|
struct And {
|
||||||
|
__device__ bool operator()(bool a, bool b) {
|
||||||
|
return a && b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Or {
|
||||||
|
__device__ bool operator()(bool a, bool b) {
|
||||||
|
return a || b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sum {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(T a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Prod {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(T a, T b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Min {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(T a, T b) {
|
||||||
|
return a < b ? a : b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Max {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(T a, T b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Traits to get the result type of reduce op.
|
||||||
|
template <typename Op, typename T>
|
||||||
|
struct ReduceResult;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<And, T> {
|
||||||
|
using type = bool;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<Or, T> {
|
||||||
|
using type = bool;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<Sum, T> {
|
||||||
|
using type = cuda::std::conditional_t<
|
||||||
|
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||||
|
int32_t,
|
||||||
|
T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<Prod, T> {
|
||||||
|
using type = cuda::std::conditional_t<
|
||||||
|
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||||
|
int32_t,
|
||||||
|
T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<Min, T> {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceResult<Max, T> {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Traits to get the init value of reduce op.
|
||||||
|
template <typename Op, typename T>
|
||||||
|
struct ReduceInit;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<And, T> {
|
||||||
|
static constexpr __host__ __device__ bool value() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<Or, T> {
|
||||||
|
static constexpr __host__ __device__ bool value() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<Sum, T> {
|
||||||
|
static constexpr __host__ __device__ auto value() {
|
||||||
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
return T{0, 0};
|
||||||
|
} else {
|
||||||
|
return typename ReduceResult<Sum, T>::type{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<Prod, T> {
|
||||||
|
static constexpr __host__ __device__ auto value() {
|
||||||
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
return T{1, 1};
|
||||||
|
} else {
|
||||||
|
return typename ReduceResult<Prod, T>::type{1};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<Min, T> {
|
||||||
|
static constexpr __host__ __device__ T value() {
|
||||||
|
return Limits<T>::max();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ReduceInit<Max, T> {
|
||||||
|
static constexpr __host__ __device__ T value() {
|
||||||
|
return Limits<T>::min();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
struct RowReduceArgs {
|
||||||
|
// The size of the row being reduced, i.e. the size of last dimension.
|
||||||
|
int row_size;
|
||||||
|
|
||||||
|
// Input shape and strides excluding the reduction axes.
|
||||||
|
Shape shape;
|
||||||
|
Strides strides;
|
||||||
|
int ndim;
|
||||||
|
|
||||||
|
// Input shape and strides of the reduction axes excluding last dimension.
|
||||||
|
Shape reduce_shape;
|
||||||
|
Strides reduce_strides;
|
||||||
|
int reduce_ndim;
|
||||||
|
|
||||||
|
// The number of rows we are reducing. Namely prod(reduce_shape).
|
||||||
|
size_t non_row_reductions;
|
||||||
|
|
||||||
|
RowReduceArgs(
|
||||||
|
const array& in,
|
||||||
|
const ReductionPlan& plan,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(!plan.shape.empty());
|
||||||
|
row_size = plan.shape.back();
|
||||||
|
|
||||||
|
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||||
|
std::tie(shape_vec, strides_vec) =
|
||||||
|
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||||
|
shape = const_param(shape_vec);
|
||||||
|
strides = const_param(strides_vec);
|
||||||
|
ndim = shape_vec.size();
|
||||||
|
|
||||||
|
reduce_shape = const_param(plan.shape);
|
||||||
|
reduce_strides = const_param(plan.strides);
|
||||||
|
reduce_ndim = plan.shape.size() - 1;
|
||||||
|
|
||||||
|
non_row_reductions = 1;
|
||||||
|
for (int i = 0; i < reduce_ndim; i++) {
|
||||||
|
non_row_reductions *= reduce_shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
|
__global__ void row_reduce_small(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
size_t out_size,
|
||||||
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
|
size_t out_idx = cg::this_grid().thread_rank();
|
||||||
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
|
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.row_size,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
|
}
|
||||||
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
out[out_idx] = total_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
|
__global__ void row_reduce_small_warp(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
size_t out_size,
|
||||||
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
|
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||||
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||||
|
n += WARP_SIZE) {
|
||||||
|
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.row_size,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
|
}
|
||||||
|
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
total_val = cg::reduce(warp, total_val, op);
|
||||||
|
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
out[out_idx] = total_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
int NDIM,
|
||||||
|
int BLOCK_DIM_X,
|
||||||
|
int N_READS = 4>
|
||||||
|
__global__ void row_reduce_looped(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
size_t out_size,
|
||||||
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||||
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
|
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||||
|
r++) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r * BLOCK_DIM_X + block.thread_index().x,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.row_size,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
|
}
|
||||||
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||||
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||||
|
|
||||||
|
if (block.thread_rank() == 0) {
|
||||||
|
out[out_idx] = total_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void row_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
cu::RowReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||||
|
constexpr size_t N_READS = 4;
|
||||||
|
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
dim3 block_dims, num_blocks;
|
||||||
|
auto kernel =
|
||||||
|
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
if (args.row_size <= 64) {
|
||||||
|
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||||
|
(args.non_row_reductions <= 8)) {
|
||||||
|
block_dims.x = std::min(out_dims.x, 1024u);
|
||||||
|
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||||
|
num_blocks.y = out_dims.y;
|
||||||
|
} else {
|
||||||
|
block_dims.x = WARP_SIZE;
|
||||||
|
num_blocks.y = out_dims.x;
|
||||||
|
num_blocks.z = out_dims.y;
|
||||||
|
kernel =
|
||||||
|
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||||
|
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||||
|
num_blocks.y = out_dims.x;
|
||||||
|
num_blocks.z = out_dims.y;
|
||||||
|
block_dims.x = BLOCK_DIM_X;
|
||||||
|
kernel = cu::row_reduce_looped<
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
OP,
|
||||||
|
NDIM,
|
||||||
|
BLOCK_DIM_X,
|
||||||
|
N_READS>;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <cub/device/device_reduce.cuh>
|
||||||
|
#include <cub/device/device_segmented_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MultiplyOp {
|
||||||
|
int factor;
|
||||||
|
__device__ int operator()(int i) {
|
||||||
|
return i * factor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void segmented_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||||
|
thrust::device_pointer_cast(in.data<InType>()));
|
||||||
|
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||||
|
auto init = cu::ReduceInit<OP, InType>::value();
|
||||||
|
|
||||||
|
if (plan.type == ContiguousAllReduce) {
|
||||||
|
cub_all_reduce(
|
||||||
|
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||||
|
} else if (plan.type == ContiguousReduce) {
|
||||||
|
auto offsets = thrust::make_transform_iterator(
|
||||||
|
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||||
|
cub_segmented_reduce(
|
||||||
|
encoder,
|
||||||
|
in_iter,
|
||||||
|
out_ptr,
|
||||||
|
out.size(),
|
||||||
|
offsets,
|
||||||
|
offsets + 1,
|
||||||
|
OP(),
|
||||||
|
init,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
160
mlx/backend/cuda/softmax.cu
Normal file
160
mlx/backend/cuda/softmax.cu
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T softmax_exp(T x) {
|
||||||
|
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||||
|
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||||
|
return __expf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
|
in += grid.block_rank() * axis_size;
|
||||||
|
out += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
cg::greater<AccT> max_op;
|
||||||
|
cg::plus<AccT> plus_op;
|
||||||
|
|
||||||
|
// Thread reduce.
|
||||||
|
AccT prevmax;
|
||||||
|
AccT maxval = Limits<AccT>::finite_min();
|
||||||
|
AccT normalizer = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||||
|
AccT vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r * BLOCK_DIM + block.thread_rank(),
|
||||||
|
make_cast_iterator<AccT>(in),
|
||||||
|
vals,
|
||||||
|
axis_size,
|
||||||
|
Limits<AccT>::finite_min());
|
||||||
|
prevmax = maxval;
|
||||||
|
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||||
|
// Online normalizer calculation for softmax:
|
||||||
|
// https://github.com/NVIDIA/online-softmax
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First warp reduce.
|
||||||
|
prevmax = maxval;
|
||||||
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||||
|
|
||||||
|
__shared__ AccT local_max[WARP_SIZE];
|
||||||
|
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||||
|
|
||||||
|
// Write to shared memory and do second warp reduce.
|
||||||
|
prevmax = maxval;
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
local_max[warp.meta_group_rank()] = maxval;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||||
|
? local_max[warp.thread_rank()]
|
||||||
|
: Limits<AccT>::finite_min();
|
||||||
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||||
|
? local_normalizer[warp.thread_rank()]
|
||||||
|
: AccT{};
|
||||||
|
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||||
|
normalizer = 1 / normalizer;
|
||||||
|
|
||||||
|
// Write output.
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, in, vals, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||||
|
}
|
||||||
|
cub::StoreDirectBlocked(index, out, vals, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Softmax::eval_gpu");
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous.
|
||||||
|
auto set_output = [&s, &out](const array& x) {
|
||||||
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
|
if (x.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(x);
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(x.data_size() * x.itemsize()),
|
||||||
|
x.data_size(),
|
||||||
|
x.strides(),
|
||||||
|
x.flags());
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
out.copy_shared_buffer(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
array in = set_output(inputs[0]);
|
||||||
|
bool precise = in.dtype() != float32 && precise_;
|
||||||
|
|
||||||
|
int axis_size = in.shape().back();
|
||||||
|
int n_rows = in.data_size() / axis_size;
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||||
|
if (precise) {
|
||||||
|
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||||
|
}
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user