From 3eb59aab6ec5d418965f978783a6c70051bf695b Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 8 Jul 2025 00:22:12 +0000 Subject: [PATCH] Do vectorized store/load in copy ops --- mlx/backend/cuda/copy/copy_contiguous.cu | 59 ++++++++++++++++++++---- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 408350129..60f66f984 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -10,19 +10,53 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_s(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[0]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = CastOp{}(in[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = CastOp{}(in[0]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void copy_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[index]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = CastOp{}(in[offset]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = CastOp{}(in_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -41,12 +75,19 @@ void copy_contiguous( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - auto kernel = cu::copy_s; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::copy_s; if (ctype == CopyType::Vector) { - kernel = cu::copy_v; + kernel = cu::copy_v; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks,