diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu
index 64c67a176..6ac42751a 100644
--- a/mlx/backend/cuda/copy/copy_general.cu
+++ b/mlx/backend/cuda/copy/copy_general.cu
@@ -10,37 +10,80 @@ namespace cu {
 
 namespace cg = cooperative_groups;
 
-template <typename In, typename Out, typename IdxT, int NDIM>
+template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
 __global__ void copy_gg_nd(
     const In* in,
     Out* out,
-    IdxT size,
+    IdxT size_rest,
     const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
-        index, shape.data(), strides_in.data(), strides_out.data());
-    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+
+  auto shape_x = shape[NDIM - 1];
+  auto in_stride_x = strides_in[NDIM - 1];
+  auto out_stride_x = strides_out[NDIM - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
+      index_rest * shape_x,
+      shape.data(),
+      strides_in.data(),
+      strides_out.data());
+
+  auto in_vec =
+      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
+  }
+  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
 }
 
-template <typename In, typename Out, typename IdxT>
+template <typename In, typename Out, typename IdxT, int N_READS>
 __global__ void copy_gg(
     const In* in,
     Out* out,
-    IdxT size,
+    IdxT size_rest,
     const __grid_constant__ Shape shape,
     const __grid_constant__ Strides strides_in,
     const __grid_constant__ Strides strides_out,
     int ndim) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto [idx_in, idx_out] = elem_to_loc(
-        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
-    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+
+  auto shape_x = shape[ndim - 1];
+  auto in_stride_x = strides_in[ndim - 1];
+  auto out_stride_x = strides_out[ndim - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto [idx_in, idx_out] = elem_to_loc(
+      index_rest * shape_x,
+      shape.data(),
+      strides_in.data(),
+      strides_out.data(),
+      ndim);
+
+  auto in_vec =
+      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
+  }
+  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
 }
 
 } // namespace cu
@@ -69,33 +112,52 @@ void copy_general(
             size_t data_size = 1;
             for (auto& s : shape)
               data_size *= s;
+
+            int work_per_thread = 1;
+            auto dim0 = ndim > 0 ? shape.back() : 1;
+            auto rest = data_size / dim0;
+            if (dim0 >= 4) {
+              work_per_thread = 4;
+            }
+
+            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+            auto block_dims = get_block_dims(dim0, rest, 1);
+            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
+            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
+
             if (ndim <= 3) {
               dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-                auto [num_blocks, block_dims] =
-                    get_launch_args(data_size, shape, out.strides(), large());
+                auto kernel =
+                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
+                if (work_per_thread == 4) {
+                  kernel =
+                      cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
+                }
                 encoder.add_kernel_node(
-                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
-                    num_blocks,
+                    kernel,
+                    {num_blocks_x, num_blocks_y},
                     block_dims,
                     0,
                     in_ptr,
                     out_ptr,
-                    data_size,
+                    rest,
                     const_param<ndim_constant()>(shape),
                     const_param<ndim_constant()>(strides_in),
                     const_param<ndim_constant()>(strides_out));
               });
             } else { // ndim >= 4
-              auto [num_blocks, block_dims] =
-                  get_launch_args(data_size, shape, out.strides(), large());
+              auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
+              if (work_per_thread == 4) {
+                kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
+              }
               encoder.add_kernel_node(
-                  cu::copy_gg<InType, OutType, IdxT>,
-                  num_blocks,
+                  kernel,
+                  {num_blocks_x, num_blocks_y},
                   block_dims,
                   0,
                   in_ptr,
                   out_ptr,
-                  data_size,
+                  rest,
                   const_param(shape),
                   const_param(strides_in),
                   const_param(strides_out),
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index bc055c9df..7ebc5d654 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -146,6 +146,23 @@ inline __device__ void store_vector(
   }
 }
 
+template <int N, typename T, typename SizeT>
+inline __device__ void store_vector(
+    T* ptr,
+    uint32_t offset,
+    const AlignedVector<T, N>& vec,
+    SizeT size,
+    int64_t stride) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
+      ptr[stride * (offset * N + i)] = vec[i];
+    }
+  }
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Type limits utils
 ///////////////////////////////////////////////////////////////////////////////
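
The new strided store_vector overload only takes the vectorized path when the destination is aligned, the chunk fits inside the innermost axis, and the output stride is 1; otherwise it writes element by element, clamping at the axis boundary. Below is a minimal host-side sketch of that scalar fallback, with hypothetical names (store_vector_fallback is illustrative, not the MLX API), assuming N elements per thread starting at logical position offset * N along an axis of length size with stride between consecutive outputs.

// Host-side model of the scalar fallback path; names are illustrative only.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

template <int N, typename T>
void store_vector_fallback(
    T* ptr,                      // start of the innermost axis in the output
    uint32_t offset,             // which chunk of N elements this thread owns
    const std::array<T, N>& vec, // values to write
    int64_t size,                // length of the innermost axis
    int64_t stride) {            // output stride along that axis
  // Write element by element, stopping at the end of the axis.
  for (int i = 0; int64_t(offset) * N + i < size && i < N; ++i) {
    ptr[stride * (int64_t(offset) * N + i)] = vec[i];
  }
}

int main() {
  // Axis of length 6 stored with stride 2, chunk size N = 4, second chunk.
  std::vector<float> out(12, 0.0f);
  store_vector_fallback<4, float>(out.data(), 1, {4.f, 5.f, 6.f, 7.f}, 6, 2);
  assert(out[8] == 4.f && out[10] == 5.f); // axis positions 4 and 5 written
  assert(out[11] == 0.f);                  // positions 6 and 7 were clamped
  return 0;
}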