Rearrange steel_matmul_regular arguments (move `copies` ahead of the batch shape/stride parameters)

This commit is contained in:
Jagrit Digani 2025-06-11 08:38:52 -07:00
parent c2f1c2a338
commit 96a7017442
3 changed files with 26 additions and 26 deletions

View File

@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose}; std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular( return steel_matmul_regular(
s, /* const Stream& s = */ s,
d, /* Device& d = */ d,
/* a = */ in_unfolded, /* const array& a = */ in_unfolded,
/* b = */ wt_transpose, /* const array& b = */ wt_transpose,
/* c = */ out, /* array& c = */ out,
/* M = */ implicit_M, /* int M = */ implicit_M,
/* N = */ implicit_N, /* int N = */ implicit_N,
/* K = */ implicit_K, /* int K = */ implicit_K,
/* batch_size_out = */ groups, /* int batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups, /* int lda = */ implicit_K * groups,
/* b_cols = */ implicit_K, /* int ldb = */ implicit_K,
/* out_cols = */ implicit_N * groups, /* int ldd = */ implicit_N * groups,
/* a_transposed = */ false, /* bool transpose_a = */ false,
/* b_transposed = */ true, /* bool transpose_b = */ true,
/* batch_shape = */ {1}, /* std::vector<array>& copies = */ copies,
/* batch_strides = */ {0}, /* Shape batch_shape = */ {1},
/* A_batch_strides = */ size_t(implicit_K), /* Strides batch_strides = */ {0},
/* B_batch_strides = */ size_t(implicit_N) * implicit_K, /* int64_t A_batch_strides = */ int64_t(implicit_K),
/* matrix_stride_out = */ size_t(implicit_N), /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/*copies = */ copies); /* int64_t matrix_stride_out = */ int64_t(implicit_N));
} }
void implicit_gemm_conv_2D_gpu( void implicit_gemm_conv_2D_gpu(

View File

@ -179,12 +179,12 @@ void steel_matmul_regular(
int ldd, int ldd,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
std::vector<array>& copies,
Shape batch_shape, Shape batch_shape,
Strides batch_strides, Strides batch_strides,
int64_t A_batch_stride, int64_t A_batch_stride,
int64_t B_batch_stride, int64_t B_batch_stride,
int64_t matrix_stride_out, int64_t matrix_stride_out) {
std::vector<array>& copies) {
using namespace mlx::steel; using namespace mlx::steel;
// Determine dispatch kernel // Determine dispatch kernel
@ -563,12 +563,12 @@ void steel_matmul(
/* int ldd = */ N, /* int ldd = */ N,
/* bool transpose_a = */ transpose_a, /* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b, /* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ std::move(batch_shape), /* Shape batch_shape = */ std::move(batch_shape),
/* Strides batch_strides = */ std::move(batch_strides), /* Strides batch_strides = */ std::move(batch_strides),
/* int64_t A_batch_stride = */ A_batch_stride.back(), /* int64_t A_batch_stride = */ A_batch_stride.back(),
/* int64_t B_batch_stride = */ B_batch_stride.back(), /* int64_t B_batch_stride = */ B_batch_stride.back(),
/* int64_t matrix_stride_out = */ matrix_stride_out, /* int64_t matrix_stride_out = */ matrix_stride_out);
/* std::vector<array>& copies = */ copies);
} }
template <bool CHECK_AB = true> template <bool CHECK_AB = true>

View File

@ -21,12 +21,12 @@ void steel_matmul_regular(
int ldd, int ldd,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
std::vector<array>& copies,
Shape batch_shape, Shape batch_shape,
Strides batch_strides, Strides batch_strides,
int64_t A_batch_stride, int64_t A_batch_stride,
int64_t B_batch_stride, int64_t B_batch_stride,
int64_t matrix_stride_out, int64_t matrix_stride_out);
std::vector<array>& copies);
void steel_matmul( void steel_matmul(
const Stream& s, const Stream& s,