Update addmm

Jagrit Digani 2025-06-11 09:30:49 -07:00
parent b3013042ca
commit dd5e833068


@@ -632,106 +632,6 @@ void steel_matmul_axpby(
       /* float beta = */ beta);
 }
-// void steel_matmul(
-//     const Stream& s,
-//     metal::Device& d,
-//     const array& a,
-//     const array& b,
-//     array& out,
-//     int M,
-//     int N,
-//     int K,
-//     int batch_size_out,
-//     int lda,
-//     int ldb,
-//     bool transpose_a,
-//     bool transpose_b,
-//     std::vector<array>& copies,
-//     Shape batch_shape /* = {} */,
-//     Strides A_batch_stride /* = {} */,
-//     Strides B_batch_stride /* = {} */) {
-//   return
-//   using namespace mlx::steel;
-//   if (batch_shape.empty()) {
-//     /////////////////////////////////////////////////////////////////////////
-//     // Check and collapse batch dimensions
-//     auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
-//     batch_shape = batch_shape_;
-//     A_batch_stride = A_bstride_;
-//     B_batch_stride = B_bstride_;
-//     // Collapse batches into M if needed
-//     if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
-//         a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
-//         B_batch_stride.back() == 0) {
-//       M *= batch_shape.back();
-//       batch_size_out = 1;
-//       A_batch_stride = {0};
-//       B_batch_stride = {0};
-//       batch_shape = {1};
-//     }
-//   }
-//   size_t matrix_stride_out = size_t(M) * N;
-//   ///////////////////////////////////////////////////////////////////////////
-//   // Split K specialization
-//   int _tm = M / 16;
-//   int _tn = N / 16;
-//   int _tk = K / 16;
-//   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-//     return steel_gemm_splitk(
-//         /* const Stream& s = */ s,
-//         /* metal::Device& d = */ d,
-//         /* const array& a = */ a,
-//         /* const array& b = */ b,
-//         /* array& out = */ out,
-//         /* int M = */ M,
-//         /* int N = */ N,
-//         /* int K = */ K,
-//         /* int batch_size_out = */ batch_size_out,
-//         /* int lda = */ lda,
-//         /* int ldb = */ ldb,
-//         /* bool transpose_a = */ transpose_a,
-//         /* bool transpose_b = */ transpose_b,
-//         /* std::vector<array>& copies = */ copies);
-//   }
-//   ///////////////////////////////////////////////////////////////////////////
-//   // Regular kernel dispatch
-//   auto batch_strides = A_batch_stride;
-//   batch_strides.insert(
-//       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-//   return steel_matmul_regular(
-//       /* const Stream& s = */ s,
-//       /* metal::Device& d = */ d,
-//       /* const array& a = */ a,
-//       /* const array& b = */ b,
-//       /* array& out = */ out,
-//       /* int M = */ M,
-//       /* int N = */ N,
-//       /* int K = */ K,
-//       /* int batch_size_out = */ batch_size_out,
-//       /* int lda = */ lda,
-//       /* int ldb = */ ldb,
-//       /* int ldd = */ N,
-//       /* bool transpose_a = */ transpose_a,
-//       /* bool transpose_b = */ transpose_b,
-//       /* std::vector<array>& copies = */ copies,
-//       /* Shape batch_shape = */ std::move(batch_shape),
-//       /* Strides batch_strides = */ std::move(batch_strides),
-//       /* int64_t A_batch_stride = */ A_batch_stride.back(),
-//       /* int64_t B_batch_stride = */ B_batch_stride.back(),
-//       /* int64_t matrix_stride_out = */ matrix_stride_out);
-// }
 template <bool CHECK_AB = true>
 void gemv_axbpy(
     const Stream& s,
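
For context on the block deleted above: the old steel_matmul chose the split-K kernel only when a single output matrix had a small M x N footprint but a deep reduction dimension, counted in tiles of 16. A minimal standalone sketch of that dispatch rule, with the thresholds taken from the removed lines (the function name and example shapes are illustrative, not part of the commit):

#include <cstdio>

// Split-K pays off when there are too few output tiles to fill the GPU,
// so extra parallelism must come from splitting the K reduction instead.
bool use_splitk(int M, int N, int K, int batch_size_out) {
  int tm = M / 16; // 16x16 output tiles along M
  int tn = N / 16; // 16x16 output tiles along N
  int tk = K / 16; // reduction depth in tiles of 16
  return batch_size_out == 1 && (tm * tn) <= 32 && tk >= 8;
}

int main() {
  std::printf("%d\n", use_splitk(64, 64, 4096, 1));     // 1: tiny output, deep K
  std::printf("%d\n", use_splitk(1024, 1024, 4096, 1)); // 0: regular kernel
}

The same heuristic reappears in the AddMM hunk below; the commit deletes both inline copies so the decision lives in one shared path (steel_matmul_axpby, judging by the new call site).
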
@@ -1108,46 +1008,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         /* float beta = */ beta_);
   }
   using namespace mlx::steel;
-  /////////////////////////////////////////////////////////////////////////////
-  // Split K specialization
-  int _tm = M / 16;
-  int _tn = N / 16;
-  int _tk = K / 16;
-  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-    return steel_gemm_splitk_axpby(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& a = */ a,
-        /* const array& b = */ b,
-        /* const array& c = */ c,
-        /* array& out = */ out,
-        /* int M = */ M,
-        /* int N = */ N,
-        /* int K = */ K,
-        /* int batch_size_out = */ batch_size_out,
-        /* int lda = */ lda,
-        /* int ldb = */ ldb,
-        /* bool transpose_a = */ transpose_a,
-        /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies,
-        /* float alpha = */ alpha_,
-        /* float beta = */ beta_);
-  }
-  /////////////////////////////////////////////////////////////////////////////
-  // Regular addmm dispatch
-  Strides batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-  batch_strides.insert(
-      batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
-  return steel_matmul_regular_axpby(
+  return steel_matmul_axpby(
       /* const Stream& s = */ s,
       /* metal::Device& d = */ d,
       /* const array& a = */ a,
@@ -1160,16 +1024,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* int batch_size_out = */ batch_size_out,
       /* int lda = */ lda,
       /* int ldb = */ ldb,
-      /* int ldd = */ ldd,
       /* bool transpose_a = */ transpose_a,
       /* bool transpose_b = */ transpose_b,
       /* std::vector<array>& copies = */ copies,
-      /* Shape batch_shape = */ std::move(batch_shape),
-      /* Strides batch_strides = */ std::move(batch_strides),
-      /* int64_t A_batch_stride = */ A_batch_stride.back(),
-      /* int64_t B_batch_stride = */ B_batch_stride.back(),
-      /* int64_t matrix_stride_out = */ int64_t(M) * ldd,
-      /* int64_t C_batch_stride = */ C_batch_stride.back(),
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride,
+      /* Strides C_batch_stride = */ C_batch_stride,
       /* float alpha = */ alpha_,
       /* float beta = */ beta_);
 }
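
Net effect of the AddMM hunks: eval_gpu no longer duplicates the split-K check or builds the flattened stride table inline; it forwards to steel_matmul_axpby, passing the A, B, and C batch strides as three separate Strides vectors instead of one concatenated batch_strides plus scalar .back() offsets and a precomputed matrix_stride_out. A small sketch of the bookkeeping that moved out of the call site (concat_batch_strides is a hypothetical name for illustration; Strides mirrors the alias used in the diff):

#include <cstdint>
#include <vector>

using Strides = std::vector<std::int64_t>;

// The removed AddMM code built one [A..., B..., C...] stride table inline
// before calling steel_matmul_regular_axpby; after this commit the caller
// hands over the three vectors and the helper does any flattening itself.
Strides concat_batch_strides(
    const Strides& a_bstride,
    const Strides& b_bstride,
    const Strides& c_bstride) {
  Strides out = a_bstride;
  out.insert(out.end(), b_bstride.begin(), b_bstride.end());
  out.insert(out.end(), c_bstride.begin(), c_bstride.end());
  return out;
}

int main() {
  // e.g. batched A, broadcast B, batched C
  Strides table = concat_batch_strides({4096}, {0}, {64}); // {4096, 0, 64}
  return table.size() == 3 ? 0 : 1;
}

Keeping the vectors separate at the call site lets the helper decide between the split-K and regular kernels itself and flatten the strides only where the regular path needs them.
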