[CUDA] Do vectorized store/load in binary ops (#2330)

This commit is contained in:
Cheng 2025-07-08 00:44:14 +09:00 committed by GitHub
parent 19facd4b20
commit 9d10239af7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 116 additions and 17 deletions

View File

@ -17,35 +17,106 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT> template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { int remaining = size - index * N_READS;
out[index] = Op{}(a[0], b[0]); if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT> template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { int remaining = size - index * N_READS;
out[index] = Op{}(a[0], b[index]); if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[0], b[offset]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT> template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { int remaining = size - index * N_READS;
out[index] = Op{}(a[index], b[0]); if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[offset], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT> template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { int remaining = size - index * N_READS;
out[index] = Op{}(a[index], b[index]); if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[offset], b[offset]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
@ -198,16 +269,23 @@ void binary_op_gpu_inplace(
} else { } else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>; // TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>; kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) { } else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>; kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) { } else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>; kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large()); kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@ -28,6 +28,27 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>; using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>; using Strides = cuda::std::array<int64_t, MAX_NDIM>;
// Vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
};
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////