mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[CUDA] Use GEMM with epilogue instead of AddMM (#2569)
This commit is contained in:
		| @@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout( | ||||
|     int32_t batch_count, | ||||
|     int64_t batch_stride) { | ||||
|   cublasLtMatrixLayout_t desc; | ||||
|   if (transposed) { | ||||
|     std::swap(rows, cols); | ||||
|   } | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); | ||||
|   cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|       desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); | ||||
|   if (batch_count > 1) { | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|         desc, | ||||
| @@ -138,25 +138,34 @@ CublasGemm::CublasGemm( | ||||
|       CUBLASLT_MATMUL_DESC_POINTER_MODE, | ||||
|       &pointer_mode, | ||||
|       sizeof(int32_t))); | ||||
|   cublasOperation_t op = CUBLAS_OP_N; | ||||
|  | ||||
|   // In cublasLt matrices use column-major layout, while it is possible to use | ||||
|   // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias | ||||
|   // epilogue does not work with the option. So instead we swap A and B to make | ||||
|   // cublasLt return the row-major result, which works because: | ||||
|   // - the data of a matrix in row-major layout is identical to its transpose in | ||||
|   //   column-major layout | ||||
|   // - C^T = (A @ B)^T = B^T @ A^T | ||||
|   cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_TRANSA, | ||||
|       &op, | ||||
|       &a_op, | ||||
|       sizeof(cublasOperation_t))); | ||||
|   cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_TRANSB, | ||||
|       &op, | ||||
|       &b_op, | ||||
|       sizeof(cublasOperation_t))); | ||||
|  | ||||
|   auto type = dtype_to_cublas_type(dtype); | ||||
|   a_desc_ = create_matrix_layout( | ||||
|       type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); | ||||
|       type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride); | ||||
|   b_desc_ = create_matrix_layout( | ||||
|       type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); | ||||
|       type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride); | ||||
|   out_desc_ = create_matrix_layout( | ||||
|       type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); | ||||
|       type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); | ||||
| } | ||||
|  | ||||
| CublasGemm::CublasGemm( | ||||
| @@ -191,7 +200,7 @@ CublasGemm::CublasGemm( | ||||
|           b_batch_stride) { | ||||
|   auto type = dtype_to_cublas_type(dtype); | ||||
|   c_desc_ = create_matrix_layout( | ||||
|       type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); | ||||
|       type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); | ||||
| } | ||||
|  | ||||
| CublasGemm::~CublasGemm() { | ||||
| @@ -213,14 +222,25 @@ void CublasGemm::set_out( | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); | ||||
|   out_desc_ = create_matrix_layout( | ||||
|       dtype_to_cublas_type(dtype), | ||||
|       rows, | ||||
|       cols, | ||||
|       rows, | ||||
|       transposed, | ||||
|       ld, | ||||
|       batch_count, | ||||
|       batch_stride); | ||||
| } | ||||
|  | ||||
| void CublasGemm::set_bias(void* bias) { | ||||
|   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_EPILOGUE, | ||||
|       &epilogue, | ||||
|       sizeof(epilogue))); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); | ||||
| } | ||||
|  | ||||
| void CublasGemm::run( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
| @@ -330,9 +350,9 @@ void CublasGemm::execute( | ||||
|       handle_, | ||||
|       matmul_desc_, | ||||
|       &alpha, | ||||
|       a, | ||||
|       b, // a and b are swapped | ||||
|       a_desc_, | ||||
|       b, | ||||
|       a, | ||||
|       b_desc_, | ||||
|       &beta, | ||||
|       c ? c : out, | ||||
|   | ||||
| @@ -55,6 +55,8 @@ class CublasGemm { | ||||
|       int32_t batch_count, | ||||
|       int64_t batch_stride); | ||||
|  | ||||
|   void set_bias(void* bias); | ||||
|  | ||||
|   void run( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       array& out, | ||||
|   | ||||
| @@ -11,6 +11,7 @@ | ||||
| #include <numeric> | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::tuple<bool, int64_t, array> | ||||
| @@ -28,6 +29,74 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| void gemm_and_bias( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     int M, | ||||
|     int N, | ||||
|     int K, | ||||
|     bool a_transposed, | ||||
|     int64_t lda, | ||||
|     bool b_transposed, | ||||
|     int64_t ldb, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     void* bias = nullptr) { | ||||
|   // Check and collapse batch dimensions | ||||
|   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); | ||||
|  | ||||
|   auto batch_count = out.size() / (M * N); | ||||
|  | ||||
|   // Collapse batches into M if needed | ||||
|   if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && | ||||
|       a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && | ||||
|       b_batch_strides.back() == 0) { | ||||
|     M *= batch_shape.back(); | ||||
|     batch_count = 1; | ||||
|  | ||||
|     a_batch_strides = {0}; | ||||
|     b_batch_strides = {0}; | ||||
|     batch_shape = {1}; | ||||
|   } | ||||
|  | ||||
|   // Use gemmv when possible | ||||
|   if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) { | ||||
|     cu::gemv( | ||||
|         a, | ||||
|         b, | ||||
|         out, | ||||
|         M, | ||||
|         N, | ||||
|         K, | ||||
|         batch_count, | ||||
|         batch_shape, | ||||
|         a_batch_strides, | ||||
|         b_batch_strides, | ||||
|         encoder); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   // Invoke cublasLt | ||||
|   CublasGemm gemm( | ||||
|       encoder.device(), | ||||
|       a.dtype(), | ||||
|       a_transposed, | ||||
|       M, | ||||
|       K, | ||||
|       lda, | ||||
|       b_transposed, | ||||
|       K, | ||||
|       N, | ||||
|       ldb, | ||||
|       batch_shape.back(), | ||||
|       a_batch_strides.back(), | ||||
|       b_batch_strides.back()); | ||||
|   if (bias) { | ||||
|     gemm.set_bias(bias); | ||||
|   } | ||||
|   gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
| @@ -48,9 +117,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|  | ||||
|   out.set_data(allocator::malloc(out.nbytes())); | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Init checks and prep | ||||
|  | ||||
|   int M = a_pre.shape(-2); | ||||
|   int N = b_pre.shape(-1); | ||||
|   int K = a_pre.shape(-1); | ||||
| @@ -60,58 +126,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); | ||||
|   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Check and collapse batch dimensions | ||||
|  | ||||
|   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); | ||||
|  | ||||
|   auto batch_count = out.size() / (M * N); | ||||
|  | ||||
|   // Collapse batches into M if needed | ||||
|   if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && | ||||
|       a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && | ||||
|       b_batch_strides.back() == 0) { | ||||
|     M *= batch_shape.back(); | ||||
|     batch_count = 1; | ||||
|  | ||||
|     a_batch_strides = {0}; | ||||
|     b_batch_strides = {0}; | ||||
|     batch_shape = {1}; | ||||
|   } | ||||
|  | ||||
|   if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) { | ||||
|     cu::gemv( | ||||
|         a, | ||||
|         b, | ||||
|         out, | ||||
|         M, | ||||
|         N, | ||||
|         K, | ||||
|         batch_count, | ||||
|         batch_shape, | ||||
|         a_batch_strides, | ||||
|         b_batch_strides, | ||||
|         encoder); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Invoke cublasLt | ||||
|   CublasGemm gemm( | ||||
|       cu::device(s.device), | ||||
|       a.dtype(), | ||||
|       a_transposed, | ||||
|       M, | ||||
|       K, | ||||
|       lda, | ||||
|       b_transposed, | ||||
|       K, | ||||
|       N, | ||||
|       ldb, | ||||
|       batch_shape.back(), | ||||
|       a_batch_strides.back(), | ||||
|       b_batch_strides.back()); | ||||
|   gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); | ||||
|   gemm_and_bias( | ||||
|       encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b); | ||||
| } | ||||
|  | ||||
| void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
| @@ -136,6 +152,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); | ||||
|   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Dispatch to GEMM with epilogue or AddMM | ||||
|  | ||||
|   if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) { | ||||
|     out.set_data(allocator::malloc(out.nbytes())); | ||||
|     gemm_and_bias( | ||||
|         encoder, | ||||
|         M, | ||||
|         N, | ||||
|         K, | ||||
|         a_transposed, | ||||
|         lda, | ||||
|         b_transposed, | ||||
|         ldb, | ||||
|         out, | ||||
|         a, | ||||
|         b, | ||||
|         c.data<void>()); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   int64_t ldc; | ||||
|   { | ||||
|     auto stx = c.strides()[c.ndim() - 2]; | ||||
| @@ -177,7 +214,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   } | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Invoke cublasLt | ||||
|   // Invoke cublasLt with AddMM settings | ||||
|  | ||||
|   CublasGemm gemm( | ||||
|       cu::device(s.device), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Cheng
					Cheng