diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index d9b9fd8af..0585dc76a 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -17,35 +17,106 @@ namespace cu {
 
 namespace cg = cooperative_groups;
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[0], b[0]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[0], b[0]);
+    }
+  } else {
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a[0], b[0]);
+    }
+
+    store_vector(out, index, out_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[0], b[index]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[0], b[offset]);
+    }
+  } else {
+    auto b_vec = load_vector<N_READS>(b, index);
+
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
+    }
+
+    store_vector(out, index, out_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[index], b[0]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[offset], b[0]);
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
+    }
+
+    store_vector(out, index, out_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    out[index] = Op{}(a[index], b[index]);
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      out[offset] = Op{}(a[offset], b[offset]);
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+    auto b_vec = load_vector<N_READS>(b, index);
+
+    AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
+    }
+
+    store_vector(out, index, out_vec);
   }
 }
 
@@ -198,16 +269,23 @@ void binary_op_gpu_inplace(
         } else {
           dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
             using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+            // TODO: Choose optimized value based on type size.
+            constexpr int N_READS = 4;
+            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
             if (bopt == BinaryOpType::ScalarVector) {
-              kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+              kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
             } else if (bopt == BinaryOpType::VectorScalar) {
-              kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+              kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorVector) {
-              kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+              kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
             }
             auto [num_blocks, block_dims] = get_launch_args(
-                kernel, out.data_size(), out.shape(), out.strides(), large());
+                kernel,
+                out.data_size(),
+                out.shape(),
+                out.strides(),
+                large(),
+                N_READS);
             encoder.add_kernel_node(
                 kernel,
                 num_blocks,
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 6e8abdd7c..89b609c45 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -28,6 +28,27 @@ namespace mlx::core::cu {
 using Shape = cuda::std::array<int32_t, MAX_NDIM>;
 using Strides = cuda::std::array<int64_t, MAX_NDIM>;
 
+// Vectorized load/store.
+template <typename T, int N>
+struct alignas(sizeof(T) * N) AlignedVector {
+  T val[N];
+};
+
+template <int N, typename T>
+inline __device__ AlignedVector<T, N> load_vector(
+    const T* ptr,
+    uint32_t offset) {
+  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+  return from[offset];
+}
+
+template <int N, typename T>
+inline __device__ void
+store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
+  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+  to[offset] = vec;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Type limits utils
 ///////////////////////////////////////////////////////////////////////////////
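
Reviewer note, not part of the patch: below is a minimal, self-contained CUDA sketch of how the new AlignedVector/load_vector/store_vector helpers are meant to be used, mirroring the remaining/N_READS tail handling of the binary kernels above. The scale_kernel and launch_scale names, the block size, and the launch-geometry math are illustrative assumptions rather than MLX API; the helpers are restated from the utils.cuh hunk so the snippet compiles on its own, and the vectorized path assumes the buffers are aligned to sizeof(T) * N_READS.

// Illustrative sketch only; names outside the restated helpers are hypothetical.
#include <cstdint>

#include <cuda_runtime.h>

namespace example {

// Restated from the utils.cuh hunk above so this sketch is self-contained.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
};

template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(const T* ptr, uint32_t offset) {
  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
  return from[offset];
}

template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
  to[offset] = vec;
}

// Each thread owns N_READS contiguous elements. Full groups take one
// vectorized load and one vectorized store; the ragged tail (when size is not
// a multiple of N_READS) falls back to scalar accesses, as in binary_vv.
// Assumes `in` and `out` are aligned to sizeof(T) * N_READS.
template <typename T, int N_READS = 4>
__global__ void scale_kernel(const T* in, T* out, int size, T factor) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int remaining = size - index * N_READS;
  if (remaining <= 0) {
    return;
  }

  if (remaining < N_READS) {
    for (int i = 0; i < remaining; ++i) {
      int offset = index * N_READS + i;
      out[offset] = in[offset] * factor;
    }
  } else {
    auto in_vec = load_vector<N_READS>(in, index);

    AlignedVector<T, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec.val[i] = in_vec.val[i] * factor;
    }

    store_vector(out, index, out_vec);
  }
}

// One thread per N_READS elements: the grid shrinks by a factor of N_READS
// compared to a scalar kernel, which is what passing N_READS through to
// get_launch_args achieves in the patch.
inline void launch_scale(
    const float* in,
    float* out,
    int size,
    float factor,
    cudaStream_t stream) {
  constexpr int N_READS = 4;
  constexpr int block_dim = 256;
  int num_threads = (size + N_READS - 1) / N_READS;
  int num_blocks = (num_threads + block_dim - 1) / block_dim;
  scale_kernel<float, N_READS><<<num_blocks, block_dim, 0, stream>>>(
      in, out, size, factor);
}

} // namespace example

The design point, as I read it, is that grouping N_READS contiguous elements per thread lets the compiler emit a single wide (e.g. 128-bit) load/store on the aligned fast path instead of N_READS scalar transactions, while the scalar tail keeps out-of-bounds accesses impossible when size is not a multiple of N_READS.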