diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 96888da97..bf2b96de8 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -37,19 +37,35 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void unary_g(
     const In* in,
     Out* out,
-    IdxT size,
+    IdxT size_rest,
     const __grid_constant__ Shape shape,
     const __grid_constant__ Strides strides,
     int ndim) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
-    out[index] = Op{}(in[idx]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+
+  auto shape_x = shape[ndim - 1];
+  auto stride_x = strides[ndim - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto idx = elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
+  auto in_vec = load_vector<N_READS>(
+      in + idx, index_x, shape_x, stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = Op{}(in_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
 }
 
 template <typename Op, typename In, typename Out>
@@ -127,8 +143,7 @@ void unary_op_gpu_inplace(
       using OutType = cuda_type_t<CTYPE_OUT>;
       if (contig) {
         using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-        // TODO: Choose optimized value based on type size.
-        constexpr int N_READS = 4;
+        constexpr int N_READS = 16 / sizeof(OutType);
         auto [num_blocks, block_dims] = get_launch_args(
             out.data_size(), out.shape(), out.strides(), large, N_READS);
         encoder.add_kernel_node(
@@ -142,18 +157,30 @@ void unary_op_gpu_inplace(
       } else {
         using IdxT = std::conditional_t<large(), int64_t, int32_t>;
         auto [shape, strides] = collapse_contiguous_dims(in);
-        auto [num_blocks, block_dims] = get_launch_args(out, large);
+        auto ndim = shape.size();
+        int work_per_thread = 1;
+        auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
+        auto dim0 = ndim > 0 ? shape.back() : 1;
+        auto rest = out.size() / dim0;
+        if (dim0 >= 4) {
+          kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
+          work_per_thread = 4;
+        }
+        dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+        auto block_dims = get_block_dims(dim0, rest, 1);
+        uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
+        uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
         encoder.add_kernel_node(
-            cu::unary_g<Op, InType, OutType, IdxT>,
-            num_blocks,
+            kernel,
+            {num_blocks_x, num_blocks_y},
             block_dims,
             0,
             in.data<InType>(),
             out.data<OutType>(),
-            out.data_size(),
+            rest,
             const_param(shape),
             const_param(strides),
-            shape.size());
+            ndim);
       }
     });
   } else {
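
Reviewer notes (the sketches below are illustrative, not part of the patch):

The rewritten cu::unary_g replaces the old one-thread-per-element scheme with
a 2-D decomposition: the y index selects one "row" (the product of all
collapsed leading dims), and the x index walks the innermost dim N_READS
elements at a time. elem_to_loc now resolves only the row base (the offset
index_rest * shape_x), so each thread does the divide/modulo address
computation once for its N_READS-element chunk instead of once per element,
and the innermost stride is applied directly. Note that the early return only
guards y; x overshoot is handled inside load_vector (the In(0) fallback fills
out-of-range lanes) and store_vector (the trailing shape_x argument clips the
tail). The sketch below reproduces that indexing pattern in plain CUDA; the
names (unary_rows, load_strided, Negate) are invented for the sketch, and it
skips the 16-byte aligned-access fast path MLX's load_vector/store_vector
helpers presumably take when pointer and stride allow it.

    // nvcc -o unary_sketch unary_sketch.cu && ./unary_sketch
    #include <cstdio>
    #include <cuda_runtime.h>

    template <int N, typename T>
    __device__ void load_strided(
        const T* row, int chunk, int size, long stride, T fallback, T (&vec)[N]) {
      // Gather N elements from a strided row, padding past-the-end lanes.
    #pragma unroll
      for (int i = 0; i < N; ++i) {
        int j = chunk * N + i;
        vec[i] = j < size ? row[j * stride] : fallback;
      }
    }

    template <typename Op, typename T, int N>
    __global__ void unary_rows(
        const T* in, T* out, int size_rest, int size_x, long stride_x,
        long stride_rest) {
      int row = blockIdx.y * blockDim.y + threadIdx.y; // index into "rest" dims
      if (row >= size_rest) return;
      int chunk = blockIdx.x * blockDim.x + threadIdx.x; // N-wide chunk along x
      T vec[N];
      load_strided<N>(in + row * stride_rest, chunk, size_x, stride_x, T(0), vec);
    #pragma unroll
      for (int i = 0; i < N; ++i) {
        int j = chunk * N + i;
        if (j < size_x) out[row * size_x + j] = Op{}(vec[i]); // contiguous output
      }
    }

    struct Negate {
      __device__ float operator()(float x) const { return -x; }
    };

    int main() {
      // A 4x5 row-major buffer read transposed: shape (5, 4), strides (1, 5).
      const int size_rest = 5, size_x = 4, n = size_rest * size_x;
      float *in, *out;
      cudaMallocManaged(&in, n * sizeof(float));
      cudaMallocManaged(&out, n * sizeof(float));
      for (int i = 0; i < n; ++i) in[i] = float(i);
      constexpr int N = 4;
      dim3 block(8, 8);
      dim3 grid(((size_x + N - 1) / N + block.x - 1) / block.x,
                (size_rest + block.y - 1) / block.y);
      unary_rows<Negate, float, N><<<grid, block>>>(
          in, out, size_rest, size_x, /*stride_x=*/5, /*stride_rest=*/1);
      cudaDeviceSynchronize();
      printf("out[1] = %g (expect -5)\n", out[1]); // (row 0, x 1) reads in[5]
      cudaFree(in);
      cudaFree(out);
    }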
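
On the contiguous path, the new N_READS expression resolves the old TODO by
sizing each thread's access at 16 bytes (one float4-sized transaction) rather
than a fixed 4 elements, so wider types read fewer elements per thread and
narrower types read more. A quick check of the arithmetic, assuming the usual
CUDA type sizes:

    #include <cuda_fp16.h>
    static_assert(16 / sizeof(double) == 2, "double -> 2 elements per thread");
    static_assert(16 / sizeof(float) == 4, "float -> 4 elements (a float4)");
    static_assert(16 / sizeof(__half) == 8, "half -> 8 elements");
    static_assert(16 / sizeof(unsigned char) == 16, "int8 -> 16 elements");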
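
On the host side, the general path opts into the 4-wide kernel only when the
innermost dim has at least 4 elements, rounds that dim up to whole chunks,
and tiles the resulting 2-D index space into blocks. A worked example of the
launch arithmetic; the 32x8 block shape is a made-up stand-in for whatever
get_block_dims actually returns:

    #include <cstdio>
    int main() {
      unsigned dim0 = 1000, rest = 300; // innermost extent, product of the rest
      int work_per_thread = dim0 >= 4 ? 4 : 1; // selects cu::unary_g<..., 4>
      dim0 = (dim0 + work_per_thread - 1) / work_per_thread; // 250 chunks
      unsigned bx = 32, by = 8; // assumed get_block_dims(250, 300, 1) result
      unsigned num_blocks_x = (dim0 + bx - 1) / bx; // ceil(250/32) = 8
      unsigned num_blocks_y = (rest + by - 1) / by; // ceil(300/8) = 38
      printf("grid=(%u,%u), block=(%u,%u)\n", num_blocks_x, num_blocks_y, bx, by);
    }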