mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
[CUDA] Use GEMM with epilogue instead of AddMM (#2569)
This commit is contained in:
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t batch_stride) {
|
int64_t batch_stride) {
|
||||||
cublasLtMatrixLayout_t desc;
|
cublasLtMatrixLayout_t desc;
|
||||||
|
if (transposed) {
|
||||||
|
std::swap(rows, cols);
|
||||||
|
}
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
||||||
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
|
||||||
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
|
|
||||||
if (batch_count > 1) {
|
if (batch_count > 1) {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
desc,
|
desc,
|
||||||
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
|
|||||||
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||||
&pointer_mode,
|
&pointer_mode,
|
||||||
sizeof(int32_t)));
|
sizeof(int32_t)));
|
||||||
cublasOperation_t op = CUBLAS_OP_N;
|
|
||||||
|
// In cublasLt matrices use column-major layout, while it is possible to use
|
||||||
|
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
|
||||||
|
// epilogue does not work with the option. So instead we swap A and B to make
|
||||||
|
// cublasLt return the row-major result, which works because:
|
||||||
|
// - the data of a matrix in row-major layout is identical to its transpose in
|
||||||
|
// column-major layout
|
||||||
|
// - C^T = (A @ B)^T = B^T @ A^T
|
||||||
|
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||||
&op,
|
&a_op,
|
||||||
sizeof(cublasOperation_t)));
|
sizeof(cublasOperation_t)));
|
||||||
|
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
CUBLASLT_MATMUL_DESC_TRANSB,
|
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||||
&op,
|
&b_op,
|
||||||
sizeof(cublasOperation_t)));
|
sizeof(cublasOperation_t)));
|
||||||
|
|
||||||
auto type = dtype_to_cublas_type(dtype);
|
auto type = dtype_to_cublas_type(dtype);
|
||||||
a_desc_ = create_matrix_layout(
|
a_desc_ = create_matrix_layout(
|
||||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
|
||||||
b_desc_ = create_matrix_layout(
|
b_desc_ = create_matrix_layout(
|
||||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
|
||||||
out_desc_ = create_matrix_layout(
|
out_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
}
|
}
|
||||||
|
|
||||||
CublasGemm::CublasGemm(
|
CublasGemm::CublasGemm(
|
||||||
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
|
|||||||
b_batch_stride) {
|
b_batch_stride) {
|
||||||
auto type = dtype_to_cublas_type(dtype);
|
auto type = dtype_to_cublas_type(dtype);
|
||||||
c_desc_ = create_matrix_layout(
|
c_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
CublasGemm::~CublasGemm() {
|
CublasGemm::~CublasGemm() {
|
||||||
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
out_desc_ = create_matrix_layout(
|
out_desc_ = create_matrix_layout(
|
||||||
dtype_to_cublas_type(dtype),
|
dtype_to_cublas_type(dtype),
|
||||||
rows,
|
|
||||||
cols,
|
cols,
|
||||||
|
rows,
|
||||||
transposed,
|
transposed,
|
||||||
ld,
|
ld,
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_stride);
|
batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CublasGemm::set_bias(void* bias) {
|
||||||
|
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
|
matmul_desc_,
|
||||||
|
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||||
|
&epilogue,
|
||||||
|
sizeof(epilogue)));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
|
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
|
||||||
|
}
|
||||||
|
|
||||||
void CublasGemm::run(
|
void CublasGemm::run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
@@ -330,9 +350,9 @@ void CublasGemm::execute(
|
|||||||
handle_,
|
handle_,
|
||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
&alpha,
|
&alpha,
|
||||||
a,
|
b, // a and b are swapped
|
||||||
a_desc_,
|
a_desc_,
|
||||||
b,
|
a,
|
||||||
b_desc_,
|
b_desc_,
|
||||||
&beta,
|
&beta,
|
||||||
c ? c : out,
|
c ? c : out,
|
||||||
|
@@ -55,6 +55,8 @@ class CublasGemm {
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t batch_stride);
|
int64_t batch_stride);
|
||||||
|
|
||||||
|
void set_bias(void* bias);
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
|
@@ -11,6 +11,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::tuple<bool, int64_t, array>
|
std::tuple<bool, int64_t, array>
|
||||||
@@ -28,6 +29,74 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void gemm_and_bias(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
bool a_transposed,
|
||||||
|
int64_t lda,
|
||||||
|
bool b_transposed,
|
||||||
|
int64_t ldb,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
void* bias = nullptr) {
|
||||||
|
// Check and collapse batch dimensions
|
||||||
|
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||||
|
|
||||||
|
auto batch_count = out.size() / (M * N);
|
||||||
|
|
||||||
|
// Collapse batches into M if needed
|
||||||
|
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||||
|
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||||
|
b_batch_strides.back() == 0) {
|
||||||
|
M *= batch_shape.back();
|
||||||
|
batch_count = 1;
|
||||||
|
|
||||||
|
a_batch_strides = {0};
|
||||||
|
b_batch_strides = {0};
|
||||||
|
batch_shape = {1};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use gemmv when possible
|
||||||
|
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||||
|
cu::gemv(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
batch_count,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides,
|
||||||
|
encoder);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke cublasLt
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
a.dtype(),
|
||||||
|
a_transposed,
|
||||||
|
M,
|
||||||
|
K,
|
||||||
|
lda,
|
||||||
|
b_transposed,
|
||||||
|
K,
|
||||||
|
N,
|
||||||
|
ldb,
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
if (bias) {
|
||||||
|
gemm.set_bias(bias);
|
||||||
|
}
|
||||||
|
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -48,9 +117,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Init checks and prep
|
|
||||||
|
|
||||||
int M = a_pre.shape(-2);
|
int M = a_pre.shape(-2);
|
||||||
int N = b_pre.shape(-1);
|
int N = b_pre.shape(-1);
|
||||||
int K = a_pre.shape(-1);
|
int K = a_pre.shape(-1);
|
||||||
@@ -60,58 +126,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
gemm_and_bias(
|
||||||
// Check and collapse batch dimensions
|
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
|
||||||
|
|
||||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
|
||||||
|
|
||||||
auto batch_count = out.size() / (M * N);
|
|
||||||
|
|
||||||
// Collapse batches into M if needed
|
|
||||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
|
||||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
|
||||||
b_batch_strides.back() == 0) {
|
|
||||||
M *= batch_shape.back();
|
|
||||||
batch_count = 1;
|
|
||||||
|
|
||||||
a_batch_strides = {0};
|
|
||||||
b_batch_strides = {0};
|
|
||||||
batch_shape = {1};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
|
||||||
cu::gemv(
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
out,
|
|
||||||
M,
|
|
||||||
N,
|
|
||||||
K,
|
|
||||||
batch_count,
|
|
||||||
batch_shape,
|
|
||||||
a_batch_strides,
|
|
||||||
b_batch_strides,
|
|
||||||
encoder);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Invoke cublasLt
|
|
||||||
CublasGemm gemm(
|
|
||||||
cu::device(s.device),
|
|
||||||
a.dtype(),
|
|
||||||
a_transposed,
|
|
||||||
M,
|
|
||||||
K,
|
|
||||||
lda,
|
|
||||||
b_transposed,
|
|
||||||
K,
|
|
||||||
N,
|
|
||||||
ldb,
|
|
||||||
batch_shape.back(),
|
|
||||||
a_batch_strides.back(),
|
|
||||||
b_batch_strides.back());
|
|
||||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -136,6 +152,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Dispatch to GEMM with epilogue or AddMM
|
||||||
|
|
||||||
|
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
gemm_and_bias(
|
||||||
|
encoder,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
a_transposed,
|
||||||
|
lda,
|
||||||
|
b_transposed,
|
||||||
|
ldb,
|
||||||
|
out,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c.data<void>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t ldc;
|
int64_t ldc;
|
||||||
{
|
{
|
||||||
auto stx = c.strides()[c.ndim() - 2];
|
auto stx = c.strides()[c.ndim() - 2];
|
||||||
@@ -177,7 +214,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt with AddMM settings
|
||||||
|
|
||||||
CublasGemm gemm(
|
CublasGemm gemm(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
|
Reference in New Issue
Block a user