From d107d8d495afd056848af937d610725c642ece61 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 22 Jul 2025 08:24:13 -0700
Subject: [PATCH] add cuda gemv (#2400)

---
 mlx/backend/cuda/CMakeLists.txt               |   1 +
 mlx/backend/cuda/binary.cu                    |   2 +-
 mlx/backend/cuda/binary_two.cu                |   2 +-
 mlx/backend/cuda/copy/copy_general.cu         |   2 +-
 mlx/backend/cuda/copy/copy_general_dynamic.cu |   2 +-
 mlx/backend/cuda/copy/copy_general_input.cu   |   2 +-
 mlx/backend/cuda/device/utils.cuh             |  16 +-
 mlx/backend/cuda/gemv.cu                      | 147 ++++++++++++++++++
 mlx/backend/cuda/gemv.h                       |  24 +++
 mlx/backend/cuda/matmul.cpp                   |  17 ++
 mlx/backend/cuda/ternary.cu                   |   2 +-
 mlx/backend/cuda/unary.cu                     |   2 +-
 12 files changed, 198 insertions(+), 21 deletions(-)
 create mode 100644 mlx/backend/cuda/gemv.cu
 create mode 100644 mlx/backend/cuda/gemv.h

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index a98308044..460507952 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 3eade024d..84a0dd04e 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -128,7 +128,7 @@ __global__ void binary_g(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto [a_idx, b_idx] = elem_to_loc_4d(
+    auto [a_idx, b_idx] = elem_to_loc(
         index, shape.data(), a_strides.data(), b_strides.data(), ndim);
     out[index] = Op{}(a[a_idx], b[b_idx]);
   }
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 3ac8a9516..dfcd81347 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -160,7 +160,7 @@ __global__ void binary_two_g(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto [a_idx, b_idx] = elem_to_loc_4d(
+    auto [a_idx, b_idx] = elem_to_loc(
         index, shape.data(), a_strides.data(), b_strides.data(), ndim);
     auto out = Op{}(a[a_idx], b[b_idx]);
     out_a[index] = out[0];
diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu
index 5c7f9f954..e92160b95 100644
--- a/mlx/backend/cuda/copy/copy_general.cu
+++ b/mlx/backend/cuda/copy/copy_general.cu
@@ -37,7 +37,7 @@ __global__ void copy_gg(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto [idx_in, idx_out] = elem_to_loc_4d(
+    auto [idx_in, idx_out] = elem_to_loc(
         index, shape.data(), strides_in.data(), strides_out.data(), ndim);
     out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
   }
diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu
index 1b643111f..419dd73fb 100644
--- a/mlx/backend/cuda/copy/copy_general_dynamic.cu
+++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu
@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
     const int64_t* offset_out) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto [idx_in, idx_out] = elem_to_loc_4d(
+    auto [idx_in, idx_out] = elem_to_loc(
         index, shape.data(), strides_in.data(), strides_out.data(), ndim);
     out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
   }
diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu
index 1ac7712e6..c66f3a777 100644
--- a/mlx/backend/cuda/copy/copy_general_input.cu
+++ b/mlx/backend/cuda/copy/copy_general_input.cu
@@ -34,7 +34,7 @@ __global__ void copy_g(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
+    IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
     out[index] = CastOp<In, Out>{}(in[idx_in]);
   }
 }
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 3745637da..c5ae14b38 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -218,20 +218,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
 }
 
-// Optimized version when ndim is larger than 4.
 template <typename IdxT = int64_t>
-inline __host__ __device__ IdxT
-elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
-  IdxT loc = 0;
-  for (int i = ndim - 1; i >= 0; --i) {
-    loc += (elem % shape[i]) * IdxT(strides[i]);
-    elem /= shape[i];
-  }
-  return loc;
-}
-
-template <typename IdxT = int64_t>
-inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
+inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
     IdxT elem,
     const int* shape,
     const int64_t* a_strides,
@@ -249,7 +237,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
 }
 
 template <typename IdxT = int64_t>
-inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
+inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
     IdxT elem,
     const int* shape,
     const int64_t* a_strides,
diff --git a/mlx/backend/cuda/gemv.cu b/mlx/backend/cuda/gemv.cu
new file mode 100644
index 000000000..fe0f7a327
--- /dev/null
+++ b/mlx/backend/cuda/gemv.cu
@@ -0,0 +1,147 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/gemv.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+
+namespace mlx::core::cu {
+
+namespace cg = cooperative_groups;
+
+static constexpr int n_per_thread = 4;
+static constexpr int rows_per_block = 8;
+
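+// Matrix-vector multiply: each block handles rows_per_block rows of the
+// matrix and each row is reduced by one warp. Every lane accumulates
+// n_per_thread elements at a time in float, and the per-lane partial sums
+// are combined with a warp reduction before lane 0 writes the result.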
+template <typename T>
+__device__ void
+gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  auto g_idx = block.group_index();
+  auto t_idx = block.thread_index();
+  int row = g_idx.x * rows_per_block + t_idx.y;
+
+  if (row < rows) {
+    float sum = 0.0f;
+    for (int col = n_per_thread * warp.thread_rank(); col < cols;
+         col += (WARP_SIZE * n_per_thread)) {
+      auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
+      auto local_vec = load_vector<n_per_thread>(vec + col, 0);
+#pragma unroll
+      for (int j = 0; j < n_per_thread; ++j) {
+        sum += static_cast<float>(local_mat.val[j]) *
+            static_cast<float>(local_vec.val[j]);
+      }
+    }
+
+    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    if (warp.thread_rank() == 0) {
+      out[row] = static_cast<T>(sum);
+    }
+  }
+}
+
+template <typename T>
+__global__ void
+gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
+  gemv_impl(mat, vec, out, rows, cols);
+}
+
+template <typename T>
+__global__ void gemv_batched(
+    const T* mat,
+    const T* vec,
+    T* out,
+    int rows,
+    int cols,
+    const __grid_constant__ Shape batch_shape,
+    const __grid_constant__ Strides mat_batch_strides,
+    const __grid_constant__ Strides vec_batch_strides,
+    int batch_ndim) {
+  auto block = cg::this_thread_block();
+  auto batch_idx = block.group_index().y;
+  auto [vec_offset, mat_offset] = elem_to_loc(
+      batch_idx,
+      batch_shape.data(),
+      vec_batch_strides.data(),
+      mat_batch_strides.data(),
+      batch_ndim);
+  gemv_impl(
+      mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
+}
+
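+// The gemv kernels only handle matrix-vector shaped products (M == 1 with a
+// transposed b, or N == 1 with a non-transposed a) where K is a multiple of
+// the vectorized warp stride (WARP_SIZE * n_per_thread).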
+bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
+  return K % (WARP_SIZE * n_per_thread) == 0 &&
+      ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+}
+
+void gemv(
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    uint32_t batch_count,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    CommandEncoder& encoder) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    dim3 block_dims{WARP_SIZE, rows_per_block};
+    const DataType* mat;
+    const DataType* vec;
+    int rows;
+    int cols = K;
+    auto mat_strides = const_param(a_batch_strides);
+    auto vec_strides = const_param(b_batch_strides);
+
+    if (M == 1) {
+      mat = b.data<DataType>();
+      vec = a.data<DataType>();
+      rows = N;
+      std::swap(mat_strides, vec_strides);
+    } else {
+      mat = a.data<DataType>();
+      vec = b.data<DataType>();
+      rows = M;
+    }
+    uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
+    if (batch_count == 1) {
+      auto kernel = gemv_single<DataType>;
+      encoder.add_kernel_node(
+          kernel,
+          num_blocks_x,
+          block_dims,
+          mat,
+          vec,
+          out.data<DataType>(),
+          rows,
+          cols);
+    } else {
+      auto kernel = gemv_batched<DataType>;
+      encoder.add_kernel_node(
+          kernel,
+          dim3{num_blocks_x, batch_count},
+          block_dims,
+          mat,
+          vec,
+          out.data<DataType>(),
+          rows,
+          cols,
+          const_param(batch_shape),
+          mat_strides,
+          vec_strides,
+          batch_shape.size());
+    }
+  });
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemv.h b/mlx/backend/cuda/gemv.h
new file mode 100644
index 000000000..27173aad5
--- /dev/null
+++ b/mlx/backend/cuda/gemv.h
@@ -0,0 +1,24 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core::cu {
+
+bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
+
+void gemv(
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    uint32_t batch_count,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    CommandEncoder& encoder);
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index c50fe7fee..1bca7c730 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -2,6 +2,7 @@
 
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemv.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
@@ -353,6 +354,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     batch_shape = {1};
   }
 
+  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
+    cu::gemv(
+        a,
+        b,
+        out,
+        M,
+        N,
+        K,
+        batch_count,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        encoder);
+    return;
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
 
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index eb69442c2..9b208c423 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -76,7 +76,7 @@ __global__ void ternary_g(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
+    auto [a_idx, b_idx, c_idx] = elem_to_loc(
         index,
         shape.data(),
         a_strides.data(),
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 6b7c94bb8..83bf83417 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -47,7 +47,7 @@ __global__ void unary_g(
     int ndim) {
   IdxT index = cg::this_grid().thread_rank();
   if (index < size) {
-    auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
+    auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
     out[index] = Op{}(in[idx]);
   }
 }