From 96a70174427bff79223ded1f1618c6c5cba0e330 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 08:38:52 -0700 Subject: [PATCH] Rearrange steel_gemm_regular --- mlx/backend/metal/conv.cpp | 40 ++++++++++++++++++------------------ mlx/backend/metal/matmul.cpp | 8 ++++---- mlx/backend/metal/matmul.h | 4 ++-- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 697afa6a1..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); + /* const Stream& s = */ s, + /* Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index b5deaf0b3..7703022c7 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -179,12 +179,12 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies) { + int64_t matrix_stride_out) { using namespace mlx::steel; // Determine dispatch kernel @@ -563,12 +563,12 @@ void steel_matmul( /* int ldd = */ N, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides batch_strides = */ std::move(batch_strides), /* int64_t A_batch_stride = */ A_batch_stride.back(), /* int64_t B_batch_stride = */ B_batch_stride.back(), - /* int64_t matrix_stride_out = */ matrix_stride_out, - /* std::vector& copies = */ copies); + /* int64_t matrix_stride_out = */ matrix_stride_out); } template diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..fb37ae6b2 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -21,12 +21,12 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out); void steel_matmul( const Stream& s,