mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 13:11:26 +08:00
add cuda gemv (#2400)
This commit is contained in:
parent
1e496ddb82
commit
d107d8d495
@ -20,6 +20,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
|
@ -128,7 +128,7 @@ __global__ void binary_g(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ __global__ void binary_two_g(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
|
@ -37,7 +37,7 @@ __global__ void copy_gg(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ __global__ void copy_g(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
|
||||
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
@ -218,20 +218,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
// Optimized version when ndim is larger than 4.
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
@ -249,7 +237,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
|
147
mlx/backend/cuda/gemv.cu
Normal file
147
mlx/backend/cuda/gemv.cu
Normal file
@ -0,0 +1,147 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/gemv.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
static constexpr int n_per_thread = 4;
|
||||
static constexpr int rows_per_block = 8;
|
||||
|
||||
template <typename T, int rows_per_block, int n_per_thread>
|
||||
__device__ void
|
||||
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
auto g_idx = block.group_index();
|
||||
auto t_idx = block.thread_index();
|
||||
int row = g_idx.x * rows_per_block + t_idx.y;
|
||||
|
||||
if (row < rows) {
|
||||
float sum = 0.0f;
|
||||
for (int col = n_per_thread * warp.thread_rank(); col < cols;
|
||||
col += (WARP_SIZE * n_per_thread)) {
|
||||
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
|
||||
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < n_per_thread; ++j) {
|
||||
sum += static_cast<float>(local_mat.val[j]) *
|
||||
static_cast<float>(local_vec.val[j]);
|
||||
}
|
||||
}
|
||||
|
||||
sum = cg::reduce(warp, sum, cg::plus<float>{});
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[row] = static_cast<T>(sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int rows_per_block, int n_per_thread>
|
||||
__global__ void
|
||||
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
|
||||
}
|
||||
|
||||
template <typename T, int rows_per_block, int n_per_thread>
|
||||
__global__ void gemv_batched(
|
||||
const T* mat,
|
||||
const T* vec,
|
||||
T* out,
|
||||
int rows,
|
||||
int cols,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides mat_batch_strides,
|
||||
const __grid_constant__ Strides vec_batch_strides,
|
||||
int batch_ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto batch_idx = block.group_index().y;
|
||||
auto [vec_offset, mat_offset] = elem_to_loc(
|
||||
batch_idx,
|
||||
batch_shape.data(),
|
||||
vec_batch_strides.data(),
|
||||
mat_batch_strides.data(),
|
||||
batch_ndim);
|
||||
gemv_impl<T, rows_per_block, n_per_thread>(
|
||||
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
|
||||
}
|
||||
|
||||
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
|
||||
return K % (WARP_SIZE * n_per_thread) == 0 &&
|
||||
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||
}
|
||||
|
||||
void gemv(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
uint32_t batch_count,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
CommandEncoder& encoder) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
dim3 block_dims{WARP_SIZE, rows_per_block};
|
||||
const DataType* mat;
|
||||
const DataType* vec;
|
||||
int rows;
|
||||
int cols = K;
|
||||
auto mat_strides = const_param(a_batch_strides);
|
||||
auto vec_strides = const_param(b_batch_strides);
|
||||
|
||||
if (M == 1) {
|
||||
mat = b.data<DataType>();
|
||||
vec = a.data<DataType>();
|
||||
rows = N;
|
||||
std::swap(mat_strides, vec_strides);
|
||||
} else {
|
||||
mat = a.data<DataType>();
|
||||
vec = b.data<DataType>();
|
||||
rows = M;
|
||||
}
|
||||
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||
if (batch_count == 1) {
|
||||
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks_x,
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols);
|
||||
} else {
|
||||
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
dim3{num_blocks_x, batch_count},
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols,
|
||||
const_param(batch_shape),
|
||||
mat_strides,
|
||||
vec_strides,
|
||||
batch_shape.size());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
24
mlx/backend/cuda/gemv.h
Normal file
24
mlx/backend/cuda/gemv.h
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
|
||||
|
||||
void gemv(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
uint32_t batch_count,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
CommandEncoder& encoder);
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemv.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -353,6 +354,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||
cu::gemv(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
|
@ -76,7 +76,7 @@ __global__ void ternary_g(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
|
@ -47,7 +47,7 @@ __global__ void unary_g(
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
|
||||
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
|
||||
out[index] = Op{}(in[idx]);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user