From aa7b47481a9188407dd99eaaa456c2b87c621ff0 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Fri, 8 Aug 2025 15:23:30 +0900
Subject: [PATCH] [CUDA] Optimize set_mm_device_pointers for small ndim
 (#2473)

---
 .../cuda/gemms/cublas_batched_gemm_12_9.cu    | 197 ++++++++++++++----
 1 file changed, 154 insertions(+), 43 deletions(-)

diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
index 86733fb06..da7163b42 100644
--- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
+++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
@@ -10,7 +10,34 @@ namespace mlx::core::cu {
 
 namespace cg = cooperative_groups;
 
-__global__ void set_mm_device_pointers(
+template <int NDIM>
+__global__ void set_mm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_mm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
       out_start + item_size * index * batch_stride;
 }
 
-__global__ void set_addmm_device_pointers(
+template <int NDIM>
+__global__ void set_addmm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
@@ -89,37 +147,62 @@ void Matmul::run_batched(
     const mlx::core::Shape& batch_shape,
     const mlx::core::Strides& a_batch_strides,
     const mlx::core::Strides& b_batch_strides) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(out_desc_, batch_count);
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
-      {static_cast<int>(batch_count * 3)},
+      allocator::malloc(batch_count * sizeof(void*) * 3),
+      {batch_count * 3},
       uint64);
 
   encoder.add_temporary(pointers);
 
-  int block_size = 512;
   encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_mm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_mm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_mm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }
 
   // Run matmul
   encoder.set_input_array(pointers);
@@ -150,7 +233,7 @@ void Matmul::run_batched(
     const mlx::core::Strides& c_batch_strides,
     float alpha,
     float beta) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(c_desc_, batch_count);
@@ -159,30 +242,58 @@ void Matmul::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
       allocator::malloc(batch_count * sizeof(uint64_t) * 4),
-      {static_cast<int>(batch_count * 4)},
+      {batch_count * 4},
       uint64);
 
   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_addmm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      c.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      const_param(c_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_addmm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          c.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          const_param<ndim_constant()>(c_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_addmm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        const_param(c_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }
 
   // Run matmul
   encoder.set_input_array(pointers);
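
Note on the fast path above: the new _nd kernels delegate the index math to
elem_to_loc_nd<NDIM>, which unravels a flat batch index into one element
offset per input. Below is a minimal sketch of what such a helper computes
for the two-input case. It is an illustration only, not MLX's actual device
helper; the real signature and return type may differ.

  // Sketch; assumes libcu++'s <cuda/std/tuple> for the return type.
  // Walks the batch shape from the innermost dimension outward, peeling
  // per-dimension positions off the flat index and accumulating each
  // input's element offset through its strides. Offsets are in elements;
  // the caller scales them by item_size. With NDIM a compile-time
  // constant, the loop fully unrolls.
  template <int NDIM>
  __device__ cuda::std::tuple<int64_t, int64_t> elem_to_loc_nd(
      uint32_t index,
      const int32_t* shape,
      const int64_t* a_strides,
      const int64_t* b_strides) {
    int64_t a_offset = 0;
    int64_t b_offset = 0;
  #pragma unroll
    for (int i = NDIM - 1; i >= 0; --i) {
      int pos = index % shape[i];
      index /= shape[i];
      a_offset += pos * a_strides[i];
      b_offset += pos * b_strides[i];
    }
    return cuda::std::make_tuple(a_offset, b_offset);
  }

The _g fallbacks perform the same walk with a runtime ndim, so the loop
cannot unroll and the shape/stride arguments travel at their full fixed
capacity; skipping that overhead for the common ndim <= 3 batches is the
point of the dispatch_1_2_3 split.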