diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 836385dfe..89d2a9e7b 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
     int32_t batch_count,
     int64_t batch_stride) {
   cublasLtMatrixLayout_t desc;
+  if (transposed) {
+    std::swap(rows, cols);
+  }
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
   if (batch_count > 1) {
     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
         desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
       CUBLASLT_MATMUL_DESC_POINTER_MODE,
       &pointer_mode,
       sizeof(int32_t)));
-  cublasOperation_t op = CUBLAS_OP_N;
+
+  // In cublasLt, matrices use column-major layout by default. While the
+  // CUBLASLT_ORDER_ROW option can switch to row-major layout, the bias
+  // epilogue does not work with that option. So instead we swap A and B to
+  // make cublasLt return the row-major result, which works because:
+  // - the data of a matrix in row-major layout is identical to that of its
+  //   transpose in column-major layout
+  // - C^T = (A @ B)^T = B^T @ A^T
+  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSA,
-      &op,
+      &a_op,
       sizeof(cublasOperation_t)));
+  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSB,
-      &op,
+      &b_op,
       sizeof(cublasOperation_t)));
 
   auto type = dtype_to_cublas_type(dtype);
   a_desc_ = create_matrix_layout(
-      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
+      type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
   b_desc_ = create_matrix_layout(
-      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
+      type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
   out_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
+      type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
 }
 
 CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
       b_batch_stride) {
   auto type = dtype_to_cublas_type(dtype);
   c_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
+      type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
 }
 
 CublasGemm::~CublasGemm() {
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
   out_desc_ = create_matrix_layout(
       dtype_to_cublas_type(dtype),
-      rows,
       cols,
+      rows,
       transposed,
       ld,
       batch_count,
       batch_stride);
 }
 
+void CublasGemm::set_bias(void* bias) {
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_EPILOGUE,
+      &epilogue,
+      sizeof(epilogue)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+}
+
 void CublasGemm::run(
     cu::CommandEncoder& encoder,
     array& out,
@@ -330,9 +350,9 @@ void CublasGemm::execute(
       handle_,
       matmul_desc_,
       &alpha,
-      a,
+      b, // a and b are swapped
       a_desc_,
-      b,
+      a,
       b_desc_,
       &beta,
       c ? c : out,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index 1b06fb2f7..857910e7f 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -55,6 +55,8 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);
 
+  void set_bias(void* bias);
+
   void run(
       cu::CommandEncoder& encoder,
       array& out,
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index b11fae538..66cd025df 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -11,6 +11,7 @@
 #include 
 
 namespace mlx::core {
+
 namespace {
 
 std::tuple<bool, int64_t, array>
@@ -28,6 +29,74 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   }
 }
 
+void gemm_and_bias(
+    cu::CommandEncoder& encoder,
+    int M,
+    int N,
+    int K,
+    bool a_transposed,
+    int64_t lda,
+    bool b_transposed,
+    int64_t ldb,
+    array& out,
+    const array& a,
+    const array& b,
+    void* bias = nullptr) {
+  // Check and collapse batch dimensions
+  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
+
+  auto batch_count = out.size() / (M * N);
+
+  // Collapse batches into M if needed
+  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
+      b_batch_strides.back() == 0) {
+    M *= batch_shape.back();
+    batch_count = 1;
+
+    a_batch_strides = {0};
+    b_batch_strides = {0};
+    batch_shape = {1};
+  }
+
+  // Use gemv when possible
+  if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
+    cu::gemv(
+        a,
+        b,
+        out,
+        M,
+        N,
+        K,
+        batch_count,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        encoder);
+    return;
+  }
+
+  // Invoke cublasLt
+  CublasGemm gemm(
+      encoder.device(),
+      a.dtype(),
+      a_transposed,
+      M,
+      K,
+      lda,
+      b_transposed,
+      K,
+      N,
+      ldb,
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+  if (bias) {
+    gemm.set_bias(bias);
+  }
+  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+}
+
 } // namespace
 
 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +117,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   out.set_data(allocator::malloc(out.nbytes()));
 
-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
   int K = a_pre.shape(-1);
@@ -60,58 +126,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
 
-  /////////////////////////////////////////////////////////////////////////////
-  // Check and collapse batch dimensions
-
-  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
-
-  auto batch_count = out.size() / (M * N);
-
-  // Collapse batches into M if needed
-  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
-      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
-      b_batch_strides.back() == 0) {
-    M *= batch_shape.back();
-    batch_count = 1;
-
-    a_batch_strides = {0};
-    b_batch_strides = {0};
-    batch_shape = {1};
-  }
-
-  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
-    cu::gemv(
-        a,
-        b,
-        out,
-        M,
-        N,
-        K,
-        batch_count,
-        batch_shape,
-        a_batch_strides,
-        b_batch_strides,
-        encoder);
-    return;
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
-  CublasGemm gemm(
-      cu::device(s.device),
-      a.dtype(),
-      a_transposed,
-      M,
-      K,
-      lda,
-      b_transposed,
-      K,
-      N,
-      ldb,
-      batch_shape.back(),
-      a_batch_strides.back(),
-      b_batch_strides.back());
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm_and_bias(
+      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
 }
 
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +152,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Dispatch to GEMM with epilogue or AddMM
+
+  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    gemm_and_bias(
+        encoder,
+        M,
+        N,
+        K,
+        a_transposed,
+        lda,
+        b_transposed,
+        ldb,
+        out,
+        a,
+        b,
+        c.data<void>());
+    return;
+  }
+
   int64_t ldc;
   {
     auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +214,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
+  // Invoke cublasLt with AddMM settings
   CublasGemm gemm(
       cu::device(s.device),
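
// ---------------------------------------------------------------------------
// Reference note (not part of the patch): a minimal, standalone C++ sketch of
// the layout identity the new constructor comment relies on. A row-major
// M x K buffer is byte-identical to a column-major K x M buffer holding the
// transpose, so asking a column-major GEMM (as cublasLt does by default) for
// B^T @ A^T yields a buffer that, read row-major, is exactly A @ B. The
// matrix sizes and names below are illustrative only.
// ---------------------------------------------------------------------------
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 2;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};    // M x K, row-major
  std::vector<float> B = {7, 8, 9, 10, 11, 12}; // K x N, row-major

  // Reference result: C = A @ B, stored row-major.
  std::vector<float> C_row(M * N, 0.f);
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C_row[i * N + j] += A[i * K + k] * B[k * N + j];

  // Column-major view: the same A and B buffers, reinterpreted column-major,
  // hold A^T (K x M) and B^T (N x K). Compute D = B^T @ A^T (N x M,
  // column-major), i.e. what the swapped-operand call returns.
  std::vector<float> D_col(N * M, 0.f);
  for (int j = 0; j < M; ++j)   // columns of D
    for (int i = 0; i < N; ++i) // rows of D
      for (int k = 0; k < K; ++k)
        // B^T(i, k) == B(k, i) and A^T(k, j) == A(j, k).
        D_col[j * N + i] += B[k * N + i] * A[j * K + k];

  // D = C^T in column-major order occupies memory exactly like C in
  // row-major order, so the two buffers match element for element.
  for (int i = 0; i < M * N; ++i)
    std::printf("%g %g\n", C_row[i], D_col[i]);
  return 0;
}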