From 85873cb162d0802c925350cbc68e1410bce3f1ad Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 10 Jul 2025 10:48:43 +0900 Subject: [PATCH] [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342) * Do vectorized store/load in unary ops * Do vectorized store/load in binary_two ops * Do vectorized store/load in copy ops * Do vectorized store/load in ternary ops * Use int32_t for IdxT * binary => binary_two in binary_two.cu * Fix tests on large arrays * Use uint as index type * Contig uses uint as index and non-contig uses int --- mlx/backend/cuda/binary.cu | 46 ++----- mlx/backend/cuda/binary_two.cu | 156 +++++++++++++++++------ mlx/backend/cuda/copy/copy_contiguous.cu | 49 +++++-- mlx/backend/cuda/ternary.cu | 34 ++++- mlx/backend/cuda/unary.cu | 34 ++++- 5 files changed, 223 insertions(+), 96 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 0585dc76a7..fc5b8c4967 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -20,15 +20,10 @@ namespace cg = cooperative_groups; template __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[0], b[0]); + if ((index + 1) * N_READS > size) { + for (int i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[0]); } } else { AlignedVector out_vec; @@ -44,15 +39,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[0], b[offset]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[i]); } } else { auto b_vec = load_vector(b, index); @@ -70,15 +60,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[offset], b[0]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[0]); } } else { auto a_vec = load_vector(a, index); @@ -96,15 +81,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[offset], b[offset]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i]); } } else { auto a_vec = load_vector(a, index); @@ -267,7 +247,7 @@ void binary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > 
UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 9582b0378c..4b6e24581b 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -17,52 +17,119 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void -binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[0], b[0]); - out_a[0] = out[0]; - out_b[0] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[0], b[index]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[index], b[0]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[index], b[index]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; 
+#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } template -__global__ void binary_g_nd( +__global__ void binary_two_g_nd( const In* a, const In* b, Out* out_a, @@ -82,7 +149,7 @@ __global__ void binary_g_nd( } template -__global__ void binary_g( +__global__ void binary_two_g( const In* a, const In* b, Out* out_a, @@ -103,7 +170,7 @@ __global__ void binary_g( } template -constexpr bool supports_binary_op() { +constexpr bool supports_binary_two_op() { if (std::is_same_v) { return std::is_same_v && (std::is_integral_v || is_floating_v); @@ -114,7 +181,7 @@ constexpr bool supports_binary_op() { } // namespace cu template -void binary_op_gpu_inplace( +void binary_two_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, std::string_view op, @@ -141,7 +208,7 @@ void binary_op_gpu_inplace( dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { using CTYPE_IN = MLX_GET_TYPE(in_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { + if constexpr (cu::supports_binary_two_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; @@ -161,8 +228,12 @@ void binary_op_gpu_inplace( int ndim = shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu:: - binary_g_nd; + auto kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); encoder.add_kernel_node( @@ -179,7 +250,7 @@ void binary_op_gpu_inplace( const_param(b_strides)); }); } else { - auto kernel = cu::binary_g; + auto kernel = cu::binary_two_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); encoder.add_kernel_node( @@ -198,22 +269,25 @@ void binary_op_gpu_inplace( } }); } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; + // TODO: Choose optimized value based on type size. 
+ constexpr int N_READS = 4; + auto kernel = cu::binary_two_ss; if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; + kernel = cu::binary_two_sv; } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; + kernel = cu::binary_two_vs; } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; + kernel = cu::binary_two_vv; } auto [num_blocks, block_dims] = get_launch_args( kernel, out_a.data_size(), out_a.shape(), out_a.strides(), - large()); + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, @@ -237,7 +311,7 @@ void binary_op_gpu_inplace( } template -void binary_op_gpu( +void binary_two_op_gpu( const std::vector& inputs, std::vector& outputs, std::string_view op, @@ -247,7 +321,7 @@ void binary_op_gpu( auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, outputs[0], bopt); set_binary_op_output_data(a, b, outputs[1], bopt); - binary_op_gpu_inplace(inputs, outputs, op, s); + binary_two_op_gpu_inplace(inputs, outputs, op, s); } void DivMod::eval_gpu( @@ -255,7 +329,7 @@ void DivMod::eval_gpu( std::vector& outputs) { nvtx3::scoped_range r("DivMod::eval_gpu"); auto& s = outputs[0].primitive().stream(); - binary_op_gpu(inputs, outputs, get_primitive_string(this), s); + binary_two_op_gpu(inputs, outputs, get_primitive_string(this), s); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 4083501299..4e9eaccb7e 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -10,19 +10,43 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_s(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[0]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in[0]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void copy_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -41,12 +65,19 @@ void copy_contiguous( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - auto kernel = cu::copy_s; + // TODO: Choose optimized value based on type size. 
+ constexpr int N_READS = 4; + auto kernel = cu::copy_s; if (ctype == CopyType::Vector) { - kernel = cu::copy_v; + kernel = cu::copy_v; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index aa6523f274..eb69442c2a 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -15,12 +15,27 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[index], b[index], c[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i], c[i]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + auto c_vec = load_vector(c, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -149,11 +164,18 @@ void ternary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 3f1a62d24c..1fe1b557bd 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -18,11 +18,24 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void unary_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(in[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(in_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -112,14 +125,20 @@ void unary_op_gpu_inplace( using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op()) { dispatch_bool(large, [&](auto large) { - using IdxT = std::conditional_t; using InType = cuda_type_t; using OutType = cuda_type_t; - using IdxT = std::conditional_t; if (contig) { - auto kernel = cu::unary_v; + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. 
+ constexpr int N_READS = 4; + auto kernel = cu::unary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large, + N_READS); encoder.add_kernel_node( kernel, num_blocks, @@ -128,6 +147,7 @@ void unary_op_gpu_inplace( out.data(), out.data_size()); } else { + using IdxT = std::conditional_t; auto [shape, strides] = collapse_contiguous_dims(in); auto kernel = cu::unary_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
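
The kernels in this patch all share one pattern: each thread owns N_READS consecutive elements; a thread whose window would run past `size` falls back to a scalar tail loop, and every other thread does one aligned vector load per input, an unrolled elementwise loop, and one aligned vector store per output. The sketch below is a minimal standalone version of that pattern. The `AlignedVector`, `load_vector`, and `store_vector` definitions here are simplified stand-ins for the MLX helpers the diff refers to (the real ones live elsewhere in mlx/backend/cuda and may differ), and the `scale_v` kernel plus the host harness are hypothetical illustrations, not part of this patch.

// A minimal sketch of the vectorized-elementwise pattern used above, assuming
// simplified stand-ins for MLX's AlignedVector/load_vector/store_vector
// helpers. scale_v and main() are hypothetical and only for illustration.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
};

// Read N consecutive elements starting at element offset index * N as a
// single aligned vector transaction (the pointer must be suitably aligned).
template <int N, typename T, typename IdxT>
__device__ AlignedVector<T, N> load_vector(const T* ptr, IdxT index) {
  return reinterpret_cast<const AlignedVector<T, N>*>(ptr)[index];
}

// Write N consecutive elements starting at element offset index * N.
template <int N, typename T, typename IdxT>
__device__ void store_vector(T* ptr, IdxT index, const AlignedVector<T, N>& vec) {
  reinterpret_cast<AlignedVector<T, N>*>(ptr)[index] = vec;
}

// Same shape as unary_v in the patch (which uses cg::this_grid().thread_rank()
// for the index): threads whose N_READS-wide window runs past the end take a
// scalar tail loop; all others take the vectorized fast path.
template <typename T, typename IdxT, int N_READS>
__global__ void scale_v(const T* in, T* out, T scale, IdxT size) {
  IdxT index = blockIdx.x * blockDim.x + threadIdx.x;

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = scale * in[i];
    }
  } else {
    auto in_vec = load_vector<N_READS>(in, index);

    AlignedVector<T, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec.val[i] = scale * in_vec.val[i];
    }

    store_vector<N_READS>(out, index, out_vec);
  }
}

int main() {
  constexpr int N_READS = 4;
  constexpr unsigned size = 1003; // deliberately not a multiple of N_READS
  float *in, *out;
  cudaMallocManaged(&in, size * sizeof(float));
  cudaMallocManaged(&out, size * sizeof(float));
  for (unsigned i = 0; i < size; ++i) in[i] = float(i);

  // Each thread handles N_READS elements, so the grid shrinks accordingly.
  unsigned threads = 256;
  unsigned blocks = (size + threads * N_READS - 1) / (threads * N_READS);
  scale_v<float, unsigned, N_READS><<<blocks, threads>>>(in, out, 2.0f, size);
  cudaDeviceSynchronize();

  printf("out[0]=%g out[%u]=%g\n", out[0], size - 1, out[size - 1]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Because each thread now covers N_READS elements, the launch configuration has to shrink by the same factor, which is why the get_launch_args calls in this patch gain an N_READS argument: the grid only needs roughly ceil(size / N_READS) threads. Per the commit notes, the contiguous paths also switch the large-array dispatch threshold from INT32_MAX to UINT32_MAX, using 32-bit unsigned indices for the common case and 64-bit indices only when data_size() genuinely exceeds that range.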