mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Redirect steel_gemm_regular
This commit is contained in:
parent
96a7017442
commit
13eccfa887
@ -164,11 +164,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
wn = 2; \
|
||||
}
|
||||
|
||||
void steel_matmul_regular(
|
||||
template <bool CHECK_AB>
|
||||
void steel_matmul_regular_axpby(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
@ -184,7 +186,10 @@ void steel_matmul_regular(
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out) {
|
||||
int64_t matrix_stride_out,
|
||||
int64_t C_batch_stride /* = 0*/,
|
||||
float alpha /* = 1.0f */,
|
||||
float beta /* = 0.0f */) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Determine dispatch kernel
|
||||
|
@ -6,7 +6,34 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void steel_matmul_regular(
|
||||
template <bool CHECK_AB = true>
|
||||
void steel_matmul_regular_axpby(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape,
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out,
|
||||
int64_t C_batch_stride = 0,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f);
|
||||
|
||||
inline void steel_matmul_regular(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
@ -26,7 +53,30 @@ void steel_matmul_regular(
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out);
|
||||
int64_t matrix_stride_out) {
|
||||
return steel_matmul_regular_axpby<false>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides batch_strides = */ batch_strides,
|
||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||
/* int64_t matrix_stride_out = */ matrix_stride_out);
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
const Stream& s,
|
||||
|
Loading…
Reference in New Issue
Block a user