mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	[CUDA] More sizes for gemv (#2429)
* route more to gemv
* route more sizes to custom gemv
This commit is contained in:
		| @@ -11,7 +11,6 @@ namespace mlx::core::cu { | ||||
|  | ||||
| namespace cg = cooperative_groups; | ||||
|  | ||||
| static constexpr int n_per_thread = 4; | ||||
| static constexpr int rows_per_block = 8; | ||||
|  | ||||
| template <typename T, int rows_per_block, int n_per_thread> | ||||
| @@ -74,8 +73,23 @@ __global__ void gemv_batched( | ||||
| } | ||||
|  | ||||
// Decides whether a matmul with the given dimensions can be routed to the
// custom gemv kernels.
//
// Returns true when K is a multiple of 32 and the product is a
// matrix-vector product in one of the two supported layouts:
//   - M == 1 with b transposed, or
//   - N == 1 with a not transposed.
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
  // Any multiple of 64 or 128 is also a multiple of 32, so a single check
  // covers every size that was previously listed separately.
  // NOTE(review): the literal 32 presumably mirrors WARP_SIZE; gemv() halves
  // its per-thread element count until K divides n_per_t * WARP_SIZE, which
  // only terminates safely when this check holds -- confirm they stay in sync.
  const bool k_is_supported = K % 32 == 0;
  const bool row_vector_times_matrix = M == 1 && b_transposed;
  const bool matrix_times_column_vector = N == 1 && !a_transposed;
  return k_is_supported &&
      (row_vector_times_matrix || matrix_times_column_vector);
}
|  | ||||
// Turns a runtime per-thread element count into a compile-time constant.
//
// Invokes `f` exactly once with a std::integral_constant<int, N> whose value
// equals `n_per_thread`, so the callee can use the count as a constant
// expression (e.g. as a template argument). Supported counts are 1, 2 and 4;
// for any other value `f` is not invoked and the function returns silently.
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
  using One = std::integral_constant<int, 1>;
  using Two = std::integral_constant<int, 2>;
  using Four = std::integral_constant<int, 4>;

  if (n_per_thread == 1) {
    f(One{});
  } else if (n_per_thread == 2) {
    f(Two{});
  } else if (n_per_thread == 4) {
    f(Four{});
  }
  // Any other count intentionally falls through without calling `f`.
}
|  | ||||
| void gemv( | ||||
| @@ -114,33 +128,39 @@ void gemv( | ||||
|       rows = M; | ||||
|     } | ||||
|     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; | ||||
|     if (batch_count == 1) { | ||||
|       auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>; | ||||
|       encoder.add_kernel_node( | ||||
|           kernel, | ||||
|           num_blocks_x, | ||||
|           block_dims, | ||||
|           mat, | ||||
|           vec, | ||||
|           out.data<DataType>(), | ||||
|           rows, | ||||
|           cols); | ||||
|     } else { | ||||
|       auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>; | ||||
|       encoder.add_kernel_node( | ||||
|           kernel, | ||||
|           dim3{num_blocks_x, batch_count}, | ||||
|           block_dims, | ||||
|           mat, | ||||
|           vec, | ||||
|           out.data<DataType>(), | ||||
|           rows, | ||||
|           cols, | ||||
|           const_param(batch_shape), | ||||
|           mat_strides, | ||||
|           vec_strides, | ||||
|           batch_shape.size()); | ||||
|     int n_per_t = 4; | ||||
|     while (K % (n_per_t * WARP_SIZE) != 0) { | ||||
|       n_per_t >>= 1; | ||||
|     } | ||||
|     dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { | ||||
|       if (batch_count == 1) { | ||||
|         auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>; | ||||
|         encoder.add_kernel_node( | ||||
|             kernel, | ||||
|             num_blocks_x, | ||||
|             block_dims, | ||||
|             mat, | ||||
|             vec, | ||||
|             out.data<DataType>(), | ||||
|             rows, | ||||
|             cols); | ||||
|       } else { | ||||
|         auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>; | ||||
|         encoder.add_kernel_node( | ||||
|             kernel, | ||||
|             dim3{num_blocks_x, batch_count}, | ||||
|             block_dims, | ||||
|             mat, | ||||
|             vec, | ||||
|             out.data<DataType>(), | ||||
|             rows, | ||||
|             cols, | ||||
|             const_param(batch_shape), | ||||
|             mat_strides, | ||||
|             vec_strides, | ||||
|             batch_shape.size()); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun