From c371baf53a7b3faa32e67bec06482f7b239c1995 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 12 Jun 2025 05:55:22 +0900
Subject: [PATCH] CUDA backend: softmax (#2272)

---
 mlx/backend/cuda/CMakeLists.txt |   2 +
 mlx/backend/cuda/logsumexp.cu   | 159 +++++++++++++++++++++++++++++++
 mlx/backend/cuda/primitives.cu  |   2 -
 mlx/backend/cuda/softmax.cu     | 160 ++++++++++++++++++++++++++++++++
 4 files changed, 321 insertions(+), 2 deletions(-)
 create mode 100644 mlx/backend/cuda/logsumexp.cu
 create mode 100644 mlx/backend/cuda/softmax.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index ab0d5fe7c..410e24096 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
@@ -27,6 +28,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu
new file mode 100644
index 000000000..e539ac559
--- /dev/null
+++ b/mlx/backend/cuda/logsumexp.cu
@@ -0,0 +1,159 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/cast_op.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <nvtx3/nvtx3.hpp>
+#include <cub/block/block_load.cuh>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T>
+inline __device__ T softmax_exp(T x) {
+  // Softmax doesn't need a high-precision exponential because x is always in
+  // (-inf, 0] and the result is subsequently divided by sum(exp(x_i)).
+  return __expf(x);
+}
+
+template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
+__global__ void logsumexp(const T* in, T* out, int axis_size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  in += grid.block_rank() * axis_size;
+
+  cg::greater<AccT> max_op;
+  cg::plus<AccT> plus_op;
+
+  // Thread reduce.
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min();
+  AccT normalizer = 0;
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
+    AccT vals[N_READS];
+    cub::LoadDirectBlocked(
+        r * BLOCK_DIM + block.thread_rank(),
+        make_cast_iterator<AccT>(in),
+        vals,
+        axis_size,
+        Limits<AccT>::min());
+    prevmax = maxval;
+    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
+    // Online normalizer calculation for softmax:
+    // https://github.com/NVIDIA/online-softmax
+    normalizer = normalizer * softmax_exp(prevmax - maxval);
+    for (int i = 0; i < N_READS; i++) {
+      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+    }
+  }
+
+  // First warp reduce.
+  prevmax = maxval;
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+
+  __shared__ AccT local_max[WARP_SIZE];
+  __shared__ AccT local_normalizer[WARP_SIZE];
+
+  // Write to shared memory and do second warp reduce.
+  prevmax = maxval;
+  if (warp.thread_rank() == 0) {
+    local_max[warp.meta_group_rank()] = maxval;
+  }
+  block.sync();
+  maxval = warp.thread_rank() < warp.meta_group_size()
+      ? local_max[warp.thread_rank()]
+      : Limits<AccT>::finite_min();
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  if (warp.thread_rank() == 0) {
+    local_normalizer[warp.meta_group_rank()] = normalizer;
+  }
+  block.sync();
+  normalizer = warp.thread_rank() < warp.meta_group_size()
+      ? local_normalizer[warp.thread_rank()]
+      : AccT{};
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+
+  // Write output.
+  if (block.thread_rank() == 0) {
+    out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
+  }
+}
+
+} // namespace cu
+
+void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("LogSumExp::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  // Make sure that the last dimension is contiguous.
+  auto ensure_contiguous = [&s, &encoder](const array& x) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      encoder.add_temporary(x_copy);
+      return x_copy;
+    }
+  };
+
+  auto in = ensure_contiguous(inputs[0]);
+  if (in.flags().row_contiguous) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  } else {
+    auto n = in.shape(-1);
+    auto flags = in.flags();
+    auto strides = in.strides();
+    for (auto& s : strides) {
+      s /= n;
+    }
+    bool col_contig = strides[0] == 1;
+    for (int i = 1; col_contig && i < strides.size(); ++i) {
+      col_contig &=
+          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
+    }
+    flags.col_contiguous = col_contig;
+    out.set_data(
+        allocator::malloc(in.nbytes() / n),
+        in.data_size() / n,
+        std::move(strides),
+        flags);
+  }
+
+  int axis_size = in.shape().back();
+  int n_rows = in.data_size() / axis_size;
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      constexpr int N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
+        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+            in.data<DataType>(), out.data<DataType>(), axis_size);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 5cf19711c..47bf68172 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -85,7 +85,6 @@ NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
-NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
@@ -95,7 +94,6 @@ NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
-NO_GPU(Softmax)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
new file mode 100644
index 000000000..605fc0df8
--- /dev/null
+++ b/mlx/backend/cuda/softmax.cu
@@ -0,0 +1,160 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/cast_op.cuh"
+#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <nvtx3/nvtx3.hpp>
+#include <cub/block/block_load.cuh>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T>
+inline __device__ T softmax_exp(T x) {
+  // Softmax doesn't need a high-precision exponential because x is always in
+  // (-inf, 0] and the result is subsequently divided by sum(exp(x_i)).
+  return __expf(x);
+}
+
+template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
+__global__ void softmax(const T* in, T* out, int axis_size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  in += grid.block_rank() * axis_size;
+  out += grid.block_rank() * axis_size;
+
+  cg::greater<AccT> max_op;
+  cg::plus<AccT> plus_op;
+
+  // Thread reduce.
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min();
+  AccT normalizer = 0;
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
+    AccT vals[N_READS];
+    cub::LoadDirectBlocked(
+        r * BLOCK_DIM + block.thread_rank(),
+        make_cast_iterator<AccT>(in),
+        vals,
+        axis_size,
+        Limits<AccT>::finite_min());
+    prevmax = maxval;
+    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
+    // Online normalizer calculation for softmax:
+    // https://github.com/NVIDIA/online-softmax
+    normalizer = normalizer * softmax_exp(prevmax - maxval);
+    for (int i = 0; i < N_READS; i++) {
+      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+    }
+  }
+
+  // First warp reduce.
+  prevmax = maxval;
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+
+  __shared__ AccT local_max[WARP_SIZE];
+  __shared__ AccT local_normalizer[WARP_SIZE];
+
+  // Write to shared memory and do second warp reduce.
+  prevmax = maxval;
+  if (warp.thread_rank() == 0) {
+    local_max[warp.meta_group_rank()] = maxval;
+  }
+  block.sync();
+  maxval = warp.thread_rank() < warp.meta_group_size()
+      ? local_max[warp.thread_rank()]
+      : Limits<AccT>::finite_min();
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  if (warp.thread_rank() == 0) {
+    local_normalizer[warp.meta_group_rank()] = normalizer;
+  }
+  block.sync();
+  normalizer = warp.thread_rank() < warp.meta_group_size()
+      ? local_normalizer[warp.thread_rank()]
+      : AccT{};
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+  normalizer = 1 / normalizer;
+
+  // Write output.
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    T vals[N_READS];
+    cub::LoadDirectBlocked(index, in, vals, axis_size);
+    for (int i = 0; i < N_READS; i++) {
+      vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
+    }
+    cub::StoreDirectBlocked(index, out, vals, axis_size);
+  }
+}
+
+} // namespace cu
+
+void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Softmax::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& s = stream();
+
+  // Make sure that the last dimension is contiguous.
+  auto set_output = [&s, &out](const array& x) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+      if (x.is_donatable()) {
+        out.copy_shared_buffer(x);
+      } else {
+        out.set_data(
+            allocator::malloc(x.data_size() * x.itemsize()),
+            x.data_size(),
+            x.strides(),
+            x.flags());
+      }
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      out.copy_shared_buffer(x_copy);
+      return x_copy;
+    }
+  };
+
+  array in = set_output(inputs[0]);
+  bool precise = in.dtype() != float32 && precise_;
+
+  int axis_size = in.shape().back();
+  int n_rows = in.data_size() / axis_size;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      constexpr int N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
+        if (precise) {
+          kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
+        }
+        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+            in.data<DataType>(), out.data<DataType>(), axis_size);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
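Editor's note (not part of the patch): both kernels above rely on the online-softmax recurrence cited in the code comments (https://github.com/NVIDIA/online-softmax). Whenever a new running maximum m' = max(m, x_i) is seen, the partial normalizer d is rescaled by exp(m - m') before exp(x_i - m') is added, so each row is reduced in a single pass; the same rescaling is applied when per-thread, per-warp, and per-block partials are merged. Below is a minimal single-threaded C++ sketch of that recurrence for one row. The helper name online_softmax_row is hypothetical and the code is illustrative only; it is not MLX API.

#include <cmath>
#include <cstdio>
#include <vector>

// Single-pass (online) softmax over one row: keep a running max and a
// normalizer that is rescaled whenever the max grows.
void online_softmax_row(const float* x, float* y, int n) {
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  for (int i = 0; i < n; ++i) {
    float prevmax = maxval;
    maxval = std::fmax(maxval, x[i]);
    // Rescale the partial sum to the new max, then add the new term.
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(x[i] - maxval);
  }
  for (int i = 0; i < n; ++i) {
    y[i] = std::exp(x[i] - maxval) / normalizer;
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  std::vector<float> y(x.size());
  online_softmax_row(x.data(), y.data(), static_cast<int>(x.size()));
  for (float v : y) {
    std::printf("%.4f\n", v); // ~0.0900, 0.2447, 0.6652
  }
  return 0;
}

With the same running state, log-sum-exp of the row is log(normalizer) + maxval, which is what the logsumexp kernel above writes for each row.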