Rearrange steel_gemm_regular

This commit is contained in:
Jagrit Digani 2025-06-11 08:38:52 -07:00
parent c2f1c2a338
commit 96a7017442
3 changed files with 26 additions and 26 deletions

View File

@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out,
/* int M = */ implicit_M,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
}
void implicit_gemm_conv_2D_gpu(

View File

@ -179,12 +179,12 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies) {
int64_t matrix_stride_out) {
using namespace mlx::steel;
// Determine dispatch kernel
@ -563,12 +563,12 @@ void steel_matmul(
/* int ldd = */ N,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ std::move(batch_shape),
/* Strides batch_strides = */ std::move(batch_strides),
/* int64_t A_batch_stride = */ A_batch_stride.back(),
/* int64_t B_batch_stride = */ B_batch_stride.back(),
/* int64_t matrix_stride_out = */ matrix_stride_out,
/* std::vector<array>& copies = */ copies);
/* int64_t matrix_stride_out = */ matrix_stride_out);
}
template <bool CHECK_AB = true>

View File

@ -21,12 +21,12 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies);
int64_t matrix_stride_out);
void steel_matmul(
const Stream& s,