Redirect steel_matmul_regular to steel_matmul_regular_axpby

This commit is contained in:
Jagrit Digani 2025-06-11 08:49:07 -07:00
parent 96a7017442
commit 13eccfa887
2 changed files with 59 additions and 4 deletions

View File

@ -164,11 +164,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
wn = 2; \ wn = 2; \
} }
void steel_matmul_regular( template <bool CHECK_AB>
void steel_matmul_regular_axpby(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
const array& a, const array& a,
const array& b, const array& b,
const array& c,
array& out, array& out,
int M, int M,
int N, int N,
@ -184,7 +186,10 @@ void steel_matmul_regular(
Strides batch_strides, Strides batch_strides,
int64_t A_batch_stride, int64_t A_batch_stride,
int64_t B_batch_stride, int64_t B_batch_stride,
int64_t matrix_stride_out) { int64_t matrix_stride_out,
int64_t C_batch_stride /* = 0*/,
float alpha /* = 1.0f */,
float beta /* = 0.0f */) {
using namespace mlx::steel; using namespace mlx::steel;
// Determine dispatch kernel // Determine dispatch kernel

View File

@ -6,7 +6,34 @@
namespace mlx::core { namespace mlx::core {
void steel_matmul_regular( template <bool CHECK_AB = true>
void steel_matmul_regular_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
int64_t C_batch_stride = 0,
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul_regular(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
const array& a, const array& a,
@ -26,7 +53,30 @@ void steel_matmul_regular(
Strides batch_strides, Strides batch_strides,
int64_t A_batch_stride, int64_t A_batch_stride,
int64_t B_batch_stride, int64_t B_batch_stride,
int64_t matrix_stride_out); int64_t matrix_stride_out) {
return steel_matmul_regular_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out);
}
void steel_matmul( void steel_matmul(
const Stream& s, const Stream& s,