mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Update addmm
This commit is contained in:
parent
b3013042ca
commit
dd5e833068
@ -632,106 +632,6 @@ void steel_matmul_axpby(
|
|||||||
/* float beta = */ beta);
|
/* float beta = */ beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
// void steel_matmul(
|
|
||||||
// const Stream& s,
|
|
||||||
// metal::Device& d,
|
|
||||||
// const array& a,
|
|
||||||
// const array& b,
|
|
||||||
// array& out,
|
|
||||||
// int M,
|
|
||||||
// int N,
|
|
||||||
// int K,
|
|
||||||
// int batch_size_out,
|
|
||||||
// int lda,
|
|
||||||
// int ldb,
|
|
||||||
// bool transpose_a,
|
|
||||||
// bool transpose_b,
|
|
||||||
// std::vector<array>& copies,
|
|
||||||
// Shape batch_shape /* = {} */,
|
|
||||||
// Strides A_batch_stride /* = {} */,
|
|
||||||
// Strides B_batch_stride /* = {} */) {
|
|
||||||
|
|
||||||
// return
|
|
||||||
|
|
||||||
// using namespace mlx::steel;
|
|
||||||
|
|
||||||
// if (batch_shape.empty()) {
|
|
||||||
// /////////////////////////////////////////////////////////////////////////////
|
|
||||||
// // Check and collapse batch dimensions
|
|
||||||
// auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
|
|
||||||
|
|
||||||
// batch_shape = batch_shape_;
|
|
||||||
// A_batch_stride = A_bstride_;
|
|
||||||
// B_batch_stride = B_bstride_;
|
|
||||||
// // Collapse batches into M if needed
|
|
||||||
// if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
|
|
||||||
// a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
|
|
||||||
// B_batch_stride.back() == 0) {
|
|
||||||
// M *= batch_shape.back();
|
|
||||||
// batch_size_out = 1;
|
|
||||||
|
|
||||||
// A_batch_stride = {0};
|
|
||||||
// B_batch_stride = {0};
|
|
||||||
// batch_shape = {1};
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// size_t matrix_stride_out = size_t(M) * N;
|
|
||||||
|
|
||||||
// /////////////////////////////////////////////////////////////////////////////
|
|
||||||
// // Split K specialization
|
|
||||||
|
|
||||||
// int _tm = M / 16;
|
|
||||||
// int _tn = N / 16;
|
|
||||||
// int _tk = K / 16;
|
|
||||||
|
|
||||||
// if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
|
||||||
// return steel_gemm_splitk(
|
|
||||||
// /* const Stream& s = */ s,
|
|
||||||
// /* metal::Device& d = */ d,
|
|
||||||
// /* const array& a = */ a,
|
|
||||||
// /* const array& b = */ b,
|
|
||||||
// /* array& out = */ out,
|
|
||||||
// /* int M = */ M,
|
|
||||||
// /* int N = */ N,
|
|
||||||
// /* int K = */ K,
|
|
||||||
// /* int batch_size_out = */ batch_size_out,
|
|
||||||
// /* int lda = */ lda,
|
|
||||||
// /* int ldb = */ ldb,
|
|
||||||
// /* bool transpose_a = */ transpose_a,
|
|
||||||
// /* bool transpose_b = */ transpose_b,
|
|
||||||
// /* std::vector<array>& copies = */ copies);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /////////////////////////////////////////////////////////////////////////////
|
|
||||||
// // Regular kernel dispatch
|
|
||||||
// auto batch_strides = A_batch_stride;
|
|
||||||
// batch_strides.insert(
|
|
||||||
// batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
|
||||||
|
|
||||||
// return steel_matmul_regular(
|
|
||||||
// /* const Stream& s = */ s,
|
|
||||||
// /* metal::Device& d = */ d,
|
|
||||||
// /* const array& a = */ a,
|
|
||||||
// /* const array& b = */ b,
|
|
||||||
// /* array& out = */ out,
|
|
||||||
// /* int M = */ M,
|
|
||||||
// /* int N = */ N,
|
|
||||||
// /* int K = */ K,
|
|
||||||
// /* int batch_size_out = */ batch_size_out,
|
|
||||||
// /* int lda = */ lda,
|
|
||||||
// /* int ldb = */ ldb,
|
|
||||||
// /* int ldd = */ N,
|
|
||||||
// /* bool transpose_a = */ transpose_a,
|
|
||||||
// /* bool transpose_b = */ transpose_b,
|
|
||||||
// /* std::vector<array>& copies = */ copies,
|
|
||||||
// /* Shape batch_shape = */ std::move(batch_shape),
|
|
||||||
// /* Strides batch_strides = */ std::move(batch_strides),
|
|
||||||
// /* int64_t A_batch_stride = */ A_batch_stride.back(),
|
|
||||||
// /* int64_t B_batch_stride = */ B_batch_stride.back(),
|
|
||||||
// /* int64_t matrix_stride_out = */ matrix_stride_out);
|
|
||||||
// }
|
|
||||||
|
|
||||||
template <bool CHECK_AB = true>
|
template <bool CHECK_AB = true>
|
||||||
void gemv_axbpy(
|
void gemv_axbpy(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@ -1108,46 +1008,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* float beta = */ beta_);
|
/* float beta = */ beta_);
|
||||||
}
|
}
|
||||||
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Split K specialization
|
|
||||||
|
|
||||||
int _tm = M / 16;
|
|
||||||
int _tn = N / 16;
|
|
||||||
int _tk = K / 16;
|
|
||||||
|
|
||||||
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
|
||||||
return steel_gemm_splitk_axpby(
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* metal::Device& d = */ d,
|
|
||||||
/* const array& a = */ a,
|
|
||||||
/* const array& b = */ b,
|
|
||||||
/* const array& c = */ c,
|
|
||||||
/* array& out = */ out,
|
|
||||||
/* int M = */ M,
|
|
||||||
/* int N = */ N,
|
|
||||||
/* int K = */ K,
|
|
||||||
/* int batch_size_out = */ batch_size_out,
|
|
||||||
/* int lda = */ lda,
|
|
||||||
/* int ldb = */ ldb,
|
|
||||||
/* bool transpose_a = */ transpose_a,
|
|
||||||
/* bool transpose_b = */ transpose_b,
|
|
||||||
/* std::vector<array>& copies = */ copies,
|
|
||||||
/* float alpha = */ alpha_,
|
|
||||||
/* float beta = */ beta_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Regular addmm dispatch
|
// Regular addmm dispatch
|
||||||
|
|
||||||
Strides batch_strides = A_batch_stride;
|
return steel_matmul_axpby(
|
||||||
batch_strides.insert(
|
|
||||||
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
|
||||||
batch_strides.insert(
|
|
||||||
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
|
|
||||||
|
|
||||||
return steel_matmul_regular_axpby(
|
|
||||||
/* const Stream& s = */ s,
|
/* const Stream& s = */ s,
|
||||||
/* metal::Device& d = */ d,
|
/* metal::Device& d = */ d,
|
||||||
/* const array& a = */ a,
|
/* const array& a = */ a,
|
||||||
@ -1160,16 +1024,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* int batch_size_out = */ batch_size_out,
|
/* int batch_size_out = */ batch_size_out,
|
||||||
/* int lda = */ lda,
|
/* int lda = */ lda,
|
||||||
/* int ldb = */ ldb,
|
/* int ldb = */ ldb,
|
||||||
/* int ldd = */ ldd,
|
|
||||||
/* bool transpose_a = */ transpose_a,
|
/* bool transpose_a = */ transpose_a,
|
||||||
/* bool transpose_b = */ transpose_b,
|
/* bool transpose_b = */ transpose_b,
|
||||||
/* std::vector<array>& copies = */ copies,
|
/* std::vector<array>& copies = */ copies,
|
||||||
/* Shape batch_shape = */ std::move(batch_shape),
|
/* Shape batch_shape = */ batch_shape,
|
||||||
/* Strides batch_strides = */ std::move(batch_strides),
|
/* Strides A_batch_stride = */ A_batch_stride,
|
||||||
/* int64_t A_batch_stride = */ A_batch_stride.back(),
|
/* Strides B_batch_stride = */ B_batch_stride,
|
||||||
/* int64_t B_batch_stride = */ B_batch_stride.back(),
|
/* Strides B_batch_stride = */ C_batch_stride,
|
||||||
/* int64_t matrix_stride_out = */ int64_t(M) * ldd,
|
|
||||||
/* int64_t C_batch_stride = */ C_batch_stride.back(),
|
|
||||||
/* float alpha = */ alpha_,
|
/* float alpha = */ alpha_,
|
||||||
/* float beta = */ beta_);
|
/* float beta = */ beta_);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user