Redirect steel_gemm

Jagrit Digani 2025-06-11 09:26:07 -07:00
parent 3ad2574d1a
commit 2e49b57ea5
2 changed files with 210 additions and 25 deletions

@@ -494,11 +494,13 @@ inline void steel_gemm_splitk(
       /* std::vector<array>& copies = */ copies);
 }
 
-void steel_matmul(
+template <bool CHECK_AB>
+void steel_matmul_axpby(
     const Stream& s,
     metal::Device& d,
     const array& a,
     const array& b,
+    const array& c,
     array& out,
     int M,
     int N,
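
The change above is the heart of the commit: the old steel_matmul entry point becomes steel_matmul_axpby, which takes a third operand c plus alpha/beta scales and is templated on CHECK_AB so the epilogue costs nothing when it is not needed. A minimal CPU reference for the intended semantics (a sketch for orientation only; gemm_axpby_ref and its flat row-major layout are assumptions of this example, not MLX code):

#include <cstddef>

// out = alpha * (a @ b) + beta * c when CHECK_AB, else out = a @ b.
template <bool CHECK_AB>
void gemm_axpby_ref(
    const float* a, const float* b, const float* c, float* out,
    int M, int N, int K, float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[size_t(i) * K + k] * b[size_t(k) * N + j];
      }
      if constexpr (CHECK_AB) {
        // AXPBY epilogue, applied once per output element.
        out[size_t(i) * N + j] = alpha * acc + beta * c[size_t(i) * N + j];
      } else {
        // Plain GEMM: c, alpha, and beta are compiled out entirely.
        out[size_t(i) * N + j] = acc;
      }
    }
  }
}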
@@ -511,32 +513,56 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape /* = {} */,
     Strides A_batch_stride /* = {} */,
-    Strides B_batch_stride /* = {} */) {
+    Strides B_batch_stride /* = {} */,
+    Strides C_batch_stride /* = {} */,
+    float alpha /* = 1.0f */,
+    float beta /* = 0.0f */) {
   using namespace mlx::steel;
 
   if (batch_shape.empty()) {
     ///////////////////////////////////////////////////////////////////////////
     // Check and collapse batch dimensions
-    auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+    if constexpr (CHECK_AB) {
+      auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] =
+          collapse_batches(a, b, c);
 
-    batch_shape = batch_shape_;
-    A_batch_stride = A_bstride_;
-    B_batch_stride = B_bstride_;
-    // Collapse batches into M if needed
-    if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
-        a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
-        B_batch_stride.back() == 0) {
-      M *= batch_shape.back();
-      batch_size_out = 1;
+      batch_shape = batch_shape_;
+      A_batch_stride = A_bstride_;
+      B_batch_stride = B_bstride_;
+      C_batch_stride = C_bstride_;
+      // Collapse batches into M if needed
+      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+          C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
+          B_batch_stride.back() == 0) {
+        M *= batch_shape.back();
+        batch_size_out = 1;
 
-      A_batch_stride = {0};
-      B_batch_stride = {0};
-      batch_shape = {1};
+        A_batch_stride = {0};
+        B_batch_stride = {0};
+        C_batch_stride = {0};
+        batch_shape = {1};
+      }
+    } else {
+      auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+
+      batch_shape = batch_shape_;
+      A_batch_stride = A_bstride_;
+      B_batch_stride = B_bstride_;
+      // Collapse batches into M if needed
+      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+          B_batch_stride.back() == 0) {
+        M *= batch_shape.back();
+        batch_size_out = 1;
+
+        A_batch_stride = {0};
+        B_batch_stride = {0};
+        batch_shape = {1};
+      }
     }
   }
 
-  size_t matrix_stride_out = size_t(M) * N;
-
   /////////////////////////////////////////////////////////////////////////////
   // Split K specialization
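
Before the split-K check, both branches of the batch-collapse logic above end in the same rewrite: when there is a single batch dimension, A is row-contiguous with its batches stacked directly onto M, and B is broadcast across the batch (stride 0), the whole batched problem is folded into one larger GEMM with M *= batch (the CHECK_AB branch additionally requires C's batch stride to line up with its rows). A hypothetical predicate spelling out those stride conditions (can_fold_batch_into_M is an illustrative name, not an MLX function):

#include <cstddef>
#include <cstdint>

bool can_fold_batch_into_M(
    int M, int K, bool transpose_a, size_t n_batch_dims,
    int64_t a_row_stride,     // a.strides()[a.ndim() - 2]
    int64_t A_batch_stride,   // stride between consecutive A matrices
    int64_t B_batch_stride) { // 0 means B is broadcast over the batch
  return !transpose_a && n_batch_dims == 1 &&
      a_row_stride == K &&                // A's rows are contiguous
      A_batch_stride == int64_t(M) * K && // batches stack directly onto M
      B_batch_stride == 0;                // same B for every batch entry
}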
@@ -545,11 +571,12 @@ void steel_matmul(
   int _tk = K / 16;
 
   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-    return steel_gemm_splitk(
+    return steel_gemm_splitk_axpby<CHECK_AB>(
         /* const Stream& s = */ s,
         /* metal::Device& d = */ d,
         /* const array& a = */ a,
         /* const array& b = */ b,
+        /* const array& c = */ c,
         /* array& out = */ out,
         /* int M = */ M,
         /* int N = */ N,
@@ -559,7 +586,9 @@ void steel_matmul(
         /* int ldb = */ ldb,
         /* bool transpose_a = */ transpose_a,
         /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies);
+        /* std::vector<array>& copies = */ copies,
+        /* float alpha = */ alpha,
+        /* float beta = */ beta);
   }
 
   /////////////////////////////////////////////////////////////////////////////
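
The condition above sends small-output, deep-K problems (at most 32 output tiles of 16x16, at least 8 K-tiles) to the split-K kernel, which now also carries c, alpha, and beta. As a CPU sketch of the split-K idea (illustrative only; gemm_splitk_ref is an assumption of this example, and the real Metal kernel runs the partitions as concurrent threadgroups and fuses the reduction with the AXPBY epilogue):

#include <algorithm>
#include <cstddef>
#include <vector>

void gemm_splitk_ref(
    const float* a, const float* b, float* out,
    int M, int N, int K, int partitions) {
  // One partial MxN accumulator per K-partition (device buffers in practice).
  std::vector<float> partials(size_t(partitions) * M * N, 0.0f);
  int chunk = (K + partitions - 1) / partitions;
  for (int p = 0; p < partitions; ++p) { // independent -> parallel on GPU
    int k0 = p * chunk;
    int k1 = std::min(K, k0 + chunk);
    float* part = partials.data() + size_t(p) * M * N;
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        for (int k = k0; k < k1; ++k)
          part[size_t(i) * N + j] += a[size_t(i) * K + k] * b[size_t(k) * N + j];
  }
  // Reduction over partitions; an AXPBY epilogue would be applied here.
  for (size_t e = 0; e < size_t(M) * N; ++e) {
    float acc = 0.0f;
    for (int p = 0; p < partitions; ++p)
      acc += partials[size_t(p) * M * N + e];
    out[e] = acc;
  }
}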
@@ -567,12 +596,21 @@ void steel_matmul(
   auto batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+  if (CHECK_AB && !C_batch_stride.empty()) {
+    batch_strides.insert(
+        batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
+  }
 
-  return steel_matmul_regular(
+  int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back();
+  int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back();
+  int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back();
+
+  return steel_matmul_regular_axpby<CHECK_AB>(
       /* const Stream& s = */ s,
       /* metal::Device& d = */ d,
       /* const array& a = */ a,
       /* const array& b = */ b,
+      /* const array& c = */ c,
       /* array& out = */ out,
       /* int M = */ M,
       /* int N = */ N,
@@ -586,11 +624,114 @@ void steel_matmul(
       /* std::vector<array>& copies = */ copies,
       /* Shape batch_shape = */ std::move(batch_shape),
       /* Strides batch_strides = */ std::move(batch_strides),
-      /* int64_t A_batch_stride = */ A_batch_stride.back(),
-      /* int64_t B_batch_stride = */ B_batch_stride.back(),
-      /* int64_t matrix_stride_out = */ matrix_stride_out);
+      /* int64_t A_batch_stride = */ A_batch_stride_,
+      /* int64_t B_batch_stride = */ B_batch_stride_,
+      /* int64_t matrix_stride_out = */ int64_t(M) * N,
+      /* int64_t C_batch_stride = */ C_batch_stride_,
+      /* float alpha = */ alpha,
+      /* float beta = */ beta);
 }
 
+// void steel_matmul(
+//     const Stream& s,
+//     metal::Device& d,
+//     const array& a,
+//     const array& b,
+//     array& out,
+//     int M,
+//     int N,
+//     int K,
+//     int batch_size_out,
+//     int lda,
+//     int ldb,
+//     bool transpose_a,
+//     bool transpose_b,
+//     std::vector<array>& copies,
+//     Shape batch_shape /* = {} */,
+//     Strides A_batch_stride /* = {} */,
+//     Strides B_batch_stride /* = {} */) {
+//   return
+//   using namespace mlx::steel;
+//
+//   if (batch_shape.empty()) {
+//     ///////////////////////////////////////////////////////////////////////
+//     // Check and collapse batch dimensions
+//     auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+//
+//     batch_shape = batch_shape_;
+//     A_batch_stride = A_bstride_;
+//     B_batch_stride = B_bstride_;
+//     // Collapse batches into M if needed
+//     if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+//         a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+//         B_batch_stride.back() == 0) {
+//       M *= batch_shape.back();
+//       batch_size_out = 1;
+//
+//       A_batch_stride = {0};
+//       B_batch_stride = {0};
+//       batch_shape = {1};
+//     }
+//   }
+//
+//   size_t matrix_stride_out = size_t(M) * N;
+//
+//   /////////////////////////////////////////////////////////////////////////
+//   // Split K specialization
+//
+//   int _tm = M / 16;
+//   int _tn = N / 16;
+//   int _tk = K / 16;
+//
+//   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+//     return steel_gemm_splitk(
+//         /* const Stream& s = */ s,
+//         /* metal::Device& d = */ d,
+//         /* const array& a = */ a,
+//         /* const array& b = */ b,
+//         /* array& out = */ out,
+//         /* int M = */ M,
+//         /* int N = */ N,
+//         /* int K = */ K,
+//         /* int batch_size_out = */ batch_size_out,
+//         /* int lda = */ lda,
+//         /* int ldb = */ ldb,
+//         /* bool transpose_a = */ transpose_a,
+//         /* bool transpose_b = */ transpose_b,
+//         /* std::vector<array>& copies = */ copies);
+//   }
+//
+//   /////////////////////////////////////////////////////////////////////////
+//   // Regular kernel dispatch
+//   auto batch_strides = A_batch_stride;
+//   batch_strides.insert(
+//       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+//
+//   return steel_matmul_regular(
+//       /* const Stream& s = */ s,
+//       /* metal::Device& d = */ d,
+//       /* const array& a = */ a,
+//       /* const array& b = */ b,
+//       /* array& out = */ out,
+//       /* int M = */ M,
+//       /* int N = */ N,
+//       /* int K = */ K,
+//       /* int batch_size_out = */ batch_size_out,
+//       /* int lda = */ lda,
+//       /* int ldb = */ ldb,
+//       /* int ldd = */ N,
+//       /* bool transpose_a = */ transpose_a,
+//       /* bool transpose_b = */ transpose_b,
+//       /* std::vector<array>& copies = */ copies,
+//       /* Shape batch_shape = */ std::move(batch_shape),
+//       /* Strides batch_strides = */ std::move(batch_strides),
+//       /* int64_t A_batch_stride = */ A_batch_stride.back(),
+//       /* int64_t B_batch_stride = */ B_batch_stride.back(),
+//       /* int64_t matrix_stride_out = */ matrix_stride_out);
+// }
+
 template <bool CHECK_AB = true>
 void gemv_axbpy(
     const Stream& s,
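
One detail of the regular-dispatch path above is worth spelling out: the per-input batch strides are concatenated into a single batch_strides vector in a fixed [A... | B... | C...] order, with C appended only when the epilogue is active, while the innermost batch stride of each input is also passed separately (A_batch_stride_, B_batch_stride_, C_batch_stride_, defaulting to 0 when empty). A standalone sketch of the packing rule (pack_batch_strides is an illustrative helper, not MLX API):

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

Strides pack_batch_strides(
    const Strides& A, const Strides& B, const Strides& C, bool check_ab) {
  Strides packed = A; // [A batch strides...]
  packed.insert(packed.end(), B.begin(), B.end()); // ...then B's
  if (check_ab && !C.empty()) {
    packed.insert(packed.end(), C.begin(), C.end()); // ...then C's, if used
  }
  return packed;
}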

@@ -78,7 +78,31 @@ inline void steel_matmul_regular(
     /* int64_t matrix_stride_out = */ matrix_stride_out);
 }
 
-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -95,6 +119,26 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}
 
 } // namespace mlx::core
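
The redirection named in the commit title lands in this last hunk: steel_matmul survives only as an inline header wrapper that forwards to steel_matmul_axpby<false>. Because CHECK_AB is false, the c operand is never read, so passing b as a placeholder for c appears to be deliberate rather than a bug, and alpha/beta keep their 1.0f/0.0f defaults, reducing the call to a plain GEMM. For contrast, an addmm-style caller would instantiate the checked variant; the call shape below is hypothetical (s, d, a, b, c, out, and copies are assumed to come from the caller's context):

// out = 2.0f * (a @ b) + 0.5f * c, with batch shapes collapsed internally.
steel_matmul_axpby</* CHECK_AB = */ true>(
    s, d, a, b, c, out,
    /* M = */ M, /* N = */ N, /* K = */ K,
    /* batch_size_out = */ batch_size_out,
    /* lda = */ lda, /* ldb = */ ldb,
    /* transpose_a = */ false, /* transpose_b = */ false,
    copies,
    /* batch_shape = */ {}, /* A_batch_stride = */ {},
    /* B_batch_stride = */ {}, /* C_batch_stride = */ {},
    /* alpha = */ 2.0f, /* beta = */ 0.5f);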