[CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)

* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
This commit is contained in:
Cheng
2025-07-10 10:48:43 +09:00
committed by GitHub
parent e14ee12491
commit 85873cb162
5 changed files with 223 additions and 96 deletions

View File

@@ -17,52 +17,119 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
@@ -82,7 +149,7 @@ __global__ void binary_g_nd(
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
@@ -103,7 +170,7 @@ __global__ void binary_g(
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
constexpr bool supports_binary_two_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
@@ -114,7 +181,7 @@ constexpr bool supports_binary_op() {
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
void binary_two_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
@@ -141,7 +208,7 @@ void binary_op_gpu_inplace(
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
@@ -161,8 +228,12 @@ void binary_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
@@ -179,7 +250,7 @@ void binary_op_gpu_inplace(
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
@@ -198,22 +269,25 @@ void binary_op_gpu_inplace(
}
});
} else {
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
large());
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
@@ -237,7 +311,7 @@ void binary_op_gpu_inplace(
}
template <typename Op>
void binary_op_gpu(
void binary_two_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
@@ -247,7 +321,7 @@ void binary_op_gpu(
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
@@ -255,7 +329,7 @@ void DivMod::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core