From f828c5b5ae186798e986b707c72484c548d8b079 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Tue, 10 Jun 2025 16:53:23 -0700
Subject: [PATCH 01/13] Refactor gemv into a function

---
 mlx/backend/metal/matmul.cpp | 413 ++++++++++++++++++-----------------
 1 file changed, 214 insertions(+), 199 deletions(-)

diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index ed96d37ea..8f4fa8e61 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -470,6 +470,178 @@ void steel_matmul(
       copies);
 }
 
+template <bool CHECK_AB = true>
+void gemv_axbpy(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f) {
+  // Collect problem info
+  bool is_b_matrix = N != 1;
+
+  auto& mat = is_b_matrix ? b : a;
+  auto& vec = is_b_matrix ? a : b;
+  bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
+  int in_vector_len = K;
+  int out_vector_len = is_b_matrix ? N : M;
+
+  int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
+  int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
+  int mat_ld = is_b_matrix ? ldb : lda;
+
+  auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
+  auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
+
+  int stride_mat = batch_strides_mat.back();
+  int stride_vec = batch_strides_vec.back();
+
+  // Determine if inputs have simple batching / broadcasting
+  bool contiguous_kernel = (batch_shape.size() == 1);
+
+  int batch_ndim = batch_shape.size();
+
+  // Determine dispatch kernel
+  int tm = 4, tn = 4;
+  int sm = 1, sn = 32;
+  int bm = 1, bn = 1;
+  int n_out_per_tgp;
+  std::ostringstream kname;
+
+  if (transpose_mat) {
+    if (in_vector_len >= 8192 && out_vector_len >= 2048) {
+      sm = 4;
+      sn = 8;
+    } else {
+      sm = 8;
+      sn = 4;
+    }
+
+    if (out_vector_len >= 2048) {
+      bn = 16;
+    } else if (out_vector_len >= 512) {
+      bn = 4;
+    } else {
+      bn = 2;
+    }
+
+    // Specialized kernel for very small outputs
+    tn = out_vector_len < tn ? 1 : tn;
+
+    n_out_per_tgp = bn * sn * tn;
+    kname << "gemv_t_" << type_to_name(out);
+
+  } else {
+    bm = out_vector_len >= 4096 ? 8 : 4;
+    sn = 32;
+
+    // Specialized kernel for very small outputs
+    tm = out_vector_len < tm ?
1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv_" << type_to_name(out); + } + + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; + kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + + if (do_axpby) { + compute_encoder.set_input_array(c, 2); + + compute_encoder.set_bytes(alpha, 7); + compute_encoder.set_bytes(beta, 8); + + compute_encoder.set_vector_bytes(C_batch_stride, 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder.set_bytes(bias_stride, 14); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(copies), s.index); +} + +inline void gemv( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -528,102 +700,26 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? 
A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby0"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } + ///////////////////////////////////////////////////////////////////////////// // Gemm specialization @@ -641,7 +737,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, - /* std::vector& = */ copies, + /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride); @@ -726,109 +822,28 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? 
a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby1"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(alpha_, 7); - compute_encoder.set_bytes(beta_, 8); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.set_vector_bytes(C_batch_stride, 13); - - int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder.set_bytes(bias_stride, 14); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ 
C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } using namespace mlx::steel; From d192587cdf4e3de4c82cbdef821d4538b7ee39c2 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 07:29:36 -0700 Subject: [PATCH 02/13] Refactor splitk step 1 --- mlx/backend/metal/matmul.cpp | 283 ++++++++++++++++++++--------------- 1 file changed, 163 insertions(+), 120 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 8f4fa8e61..4b3d82d24 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -295,6 +295,127 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } +void steel_gemm_splitk( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + auto kernel = + get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + void steel_matmul( const Stream& s, metal::Device& d, @@ -346,99 +467,21 @@ void steel_matmul( int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return steel_gemm_splitk( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies); } ///////////////////////////////////////////////////////////////////////////// @@ -447,27 +490,27 @@ void steel_matmul( batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + return steel_matmul_regular( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t 
A_batch_stride = */ A_batch_stride.back(), + /* int64_t B_batch_stride = */ B_batch_stride.back(), + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* std::vector& copies = */ copies); } template @@ -715,9 +758,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } ///////////////////////////////////////////////////////////////////////////// @@ -738,9 +781,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { From 7e9ac08a61ee58b245e8ae9216798d36b7df26a3 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 07:47:56 -0700 Subject: [PATCH 03/13] Refactor split k axpby --- mlx/backend/metal/matmul.cpp | 184 +++++++++++++++-------------------- 1 file changed, 80 insertions(+), 104 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 4b3d82d24..b5deaf0b3 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -295,11 +295,13 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_gemm_splitk( +template +void steel_gemm_splitk_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -309,7 +311,9 @@ void steel_gemm_splitk( int ldb, bool transpose_a, bool transpose_b, - std::vector& copies) { + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { using namespace mlx::steel; int _tm = M / 16; @@ -393,11 +397,21 @@ void steel_gemm_splitk( // Do accum kernel { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); - auto kernel = - get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false); + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel @@ -407,6 +421,17 @@ void steel_gemm_splitk( compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(N, 4); + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + // Launch enough thread groups for each output MTL::Size 
grid_dims = MTL::Size(N, M, 1); auto group_dims = get_block_dims(N, M, 1); @@ -416,6 +441,39 @@ void steel_gemm_splitk( d.add_temporaries(std::move(copies), s.index); } +inline void steel_gemm_splitk( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies); +} + void steel_matmul( const Stream& s, metal::Device& d, @@ -899,106 +957,24 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// From 13585ba4ee3a1769dbd0b60b0dc71deacadc2bca Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 08:38:52 -0700 Subject: [PATCH 04/13] Rearrange steel_gemm_regular --- mlx/backend/metal/conv.cpp | 40 ++++++++++++++++++------------------ mlx/backend/metal/matmul.cpp | 8 ++++---- mlx/backend/metal/matmul.h | 4 ++-- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 697afa6a1..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * 
groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); + /* const Stream& s = */ s, + /* Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index b5deaf0b3..7703022c7 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -179,12 +179,12 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies) { + int64_t matrix_stride_out) { using namespace mlx::steel; // Determine dispatch kernel @@ -563,12 +563,12 @@ void steel_matmul( /* int ldd = */ N, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides batch_strides = */ std::move(batch_strides), /* int64_t A_batch_stride = */ A_batch_stride.back(), /* int64_t B_batch_stride = */ B_batch_stride.back(), - /* int64_t matrix_stride_out = */ matrix_stride_out, - /* std::vector& copies = */ copies); + /* int64_t matrix_stride_out = */ matrix_stride_out); } template diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..fb37ae6b2 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -21,12 +21,12 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out); void steel_matmul( const Stream& s, From a733dae4baefe2f7873623113ae97fb961773e64 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 08:49:07 -0700 Subject: [PATCH 05/13] Redirect steel_gemm_regular --- mlx/backend/metal/matmul.cpp | 9 ++++-- mlx/backend/metal/matmul.h | 54 ++++++++++++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 7703022c7..0ee189e47 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,11 +164,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -184,7 +186,10 @@ void steel_matmul_regular( Strides batch_strides, int64_t A_batch_stride, 
int64_t B_batch_stride, - int64_t matrix_stride_out) { + int64_t matrix_stride_out, + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index fb37ae6b2..9c898b282 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -6,7 +6,34 @@ namespace mlx::core { -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, @@ -26,7 +53,30 @@ void steel_matmul_regular( Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out); + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} void steel_matmul( const Stream& s, From dd1b6fa629e13b25c194bd4fe303b23519a79912 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 08:54:42 -0700 Subject: [PATCH 06/13] Add axpby routing to steel_matmul_regular --- mlx/backend/metal/matmul.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 0ee189e47..1c9ee5a6c 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -209,8 +209,8 @@ void steel_matmul_regular_axpby( std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -294,6 +294,21 @@ void steel_matmul_regular_axpby( compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } + 
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies From 3ad2574d1a3f42455b0facb16f77802181491aae Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:01:45 -0700 Subject: [PATCH 07/13] Refactor AddMM step 1 --- mlx/backend/metal/matmul.cpp | 137 +++++++---------------------------- 1 file changed, 25 insertions(+), 112 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 1c9ee5a6c..03d5a89cb 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1000,124 +1000,37 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 
0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - Strides batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); batch_strides.insert( batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride.back(), + /* int64_t B_batch_stride = */ B_batch_stride.back(), + /* int64_t matrix_stride_out = */ int64_t(M) * ldd, + /* int64_t C_batch_stride = */ C_batch_stride.back(), + /* float alpha = */ alpha_, + /* float beta = */ beta_); } void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { From 2e49b57ea5b30a0da1157d0c27e58366caaae0ac Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:26:07 -0700 Subject: [PATCH 08/13] Redirect steel_gemm --- mlx/backend/metal/matmul.cpp | 187 ++++++++++++++++++++++++++++++----- mlx/backend/metal/matmul.h | 48 ++++++++- 2 files changed, 210 insertions(+), 25 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 03d5a89cb..3a17686fc 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -494,11 +494,13 @@ inline void steel_gemm_splitk( /* std::vector& copies = */ copies); } -void steel_matmul( +template +void steel_matmul_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -511,32 +513,56 @@ void steel_matmul( std::vector& 
copies, Shape batch_shape /* = {} */, Strides A_batch_stride /* = {} */, - Strides B_batch_stride /* = {} */) { + Strides B_batch_stride /* = {} */, + Strides C_batch_stride /* = {} */, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; if (batch_shape.empty()) { ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions - auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + if constexpr (CHECK_AB) { + auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = + collapse_batches(a, b, c); - batch_shape = batch_shape_; - A_batch_stride = A_bstride_; - B_batch_stride = B_bstride_; - // Collapse batches into M if needed - if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && - B_batch_stride.back() == 0) { - M *= batch_shape.back(); - batch_size_out = 1; + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + C_batch_stride = C_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; - A_batch_stride = {0}; - B_batch_stride = {0}; - batch_shape = {1}; + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + } else { + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } } } - size_t matrix_stride_out = size_t(M) * N; - ///////////////////////////////////////////////////////////////////////////// // Split K specialization @@ -545,11 +571,12 @@ void steel_matmul( int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - return steel_gemm_splitk( + return steel_gemm_splitk_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, + /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, @@ -559,7 +586,9 @@ void steel_matmul( /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies); + /* std::vector& copies = */ copies, + /* float alpha = */ alpha, + /* float beta = */ beta); } ///////////////////////////////////////////////////////////////////////////// @@ -567,12 +596,21 @@ void steel_matmul( auto batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + if (CHECK_AB && !C_batch_stride.empty()) { + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); + } - return steel_matmul_regular( + int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); + int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); + int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back(); + + return steel_matmul_regular_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, + /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, @@ -586,11 +624,114 @@ void steel_matmul( /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides batch_strides = */ std::move(batch_strides), - /* int64_t A_batch_stride = */ A_batch_stride.back(), - /* int64_t B_batch_stride = */ B_batch_stride.back(), - /* int64_t matrix_stride_out = */ matrix_stride_out); + /* int64_t A_batch_stride = */ A_batch_stride_, + /* int64_t B_batch_stride = */ B_batch_stride_, + /* int64_t matrix_stride_out = */ int64_t(M) * N, + /* int64_t C_batch_stride = */ C_batch_stride_, + /* float alpha = */ alpha, + /* float beta = */ beta); } +// void steel_matmul( +// const Stream& s, +// metal::Device& d, +// const array& a, +// const array& b, +// array& out, +// int M, +// int N, +// int K, +// int batch_size_out, +// int lda, +// int ldb, +// bool transpose_a, +// bool transpose_b, +// std::vector& copies, +// Shape batch_shape /* = {} */, +// Strides A_batch_stride /* = {} */, +// Strides B_batch_stride /* = {} */) { + +// return + +// using namespace mlx::steel; + +// if (batch_shape.empty()) { +// ///////////////////////////////////////////////////////////////////////////// +// // Check and collapse batch dimensions +// auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + +// batch_shape = batch_shape_; +// A_batch_stride = A_bstride_; +// B_batch_stride = B_bstride_; +// // Collapse batches into M if needed +// if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && +// a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && +// B_batch_stride.back() == 0) { +// M *= batch_shape.back(); +// batch_size_out = 1; + +// A_batch_stride = {0}; +// B_batch_stride = {0}; +// batch_shape = {1}; +// } +// } + +// size_t matrix_stride_out = size_t(M) * N; + +// ///////////////////////////////////////////////////////////////////////////// +// // Split K specialization + +// int _tm = M / 16; +// int _tn = N / 16; +// int _tk = K / 16; + +// if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { +// return steel_gemm_splitk( +// /* const Stream& s = */ s, +// /* metal::Device& d = */ d, +// /* const array& a = */ a, +// /* const array& b = */ b, +// /* array& out = */ out, +// /* int M = */ M, +// /* int N = */ N, +// /* int K = */ K, +// /* int batch_size_out = */ batch_size_out, +// /* int lda = */ lda, +// /* int ldb = */ ldb, +// /* bool transpose_a = */ transpose_a, +// /* bool transpose_b = */ transpose_b, +// /* std::vector& copies = */ copies); +// } + +// ///////////////////////////////////////////////////////////////////////////// +// // Regular kernel dispatch +// auto batch_strides = A_batch_stride; +// batch_strides.insert( +// batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + +// return steel_matmul_regular( +// /* const Stream& s = */ s, +// /* metal::Device& d = */ d, +// /* const array& a = */ a, +// /* const array& b = */ b, +// /* array& out = */ out, +// /* int M = */ M, +// /* int N = */ N, +// /* int K = */ K, +// /* int batch_size_out = */ batch_size_out, +// /* int lda = */ lda, +// /* int ldb = */ ldb, +// /* int ldd = */ N, +// /* bool transpose_a = */ transpose_a, +// /* bool transpose_b = */ transpose_b, +// /* std::vector& copies = */ copies, +// /* Shape batch_shape = */ 
std::move(batch_shape), +// /* Strides batch_strides = */ std::move(batch_strides), +// /* int64_t A_batch_stride = */ A_batch_stride.back(), +// /* int64_t B_batch_stride = */ B_batch_stride.back(), +// /* int64_t matrix_stride_out = */ matrix_stride_out); +// } + template void gemv_axbpy( const Stream& s, diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 9c898b282..218664b1f 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -78,7 +78,31 @@ inline void steel_matmul_regular( /* int64_t matrix_stride_out = */ matrix_stride_out); } -void steel_matmul( +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -95,6 +119,26 @@ void steel_matmul( std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, - Strides B_batch_stride = {}); + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} } // namespace mlx::core From 04bb802cd0777117d5e9c2a0cecbf4e76f297357 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:30:49 -0700 Subject: [PATCH 09/13] Update addmm --- mlx/backend/metal/matmul.cpp | 149 ++--------------------------------- 1 file changed, 5 insertions(+), 144 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3a17686fc..bc4cee56f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -632,106 +632,6 @@ void steel_matmul_axpby( /* float beta = */ beta); } -// void steel_matmul( -// const Stream& s, -// metal::Device& d, -// const array& a, -// const array& b, -// array& out, -// int M, -// int N, -// int K, -// int batch_size_out, -// int lda, -// int ldb, -// bool transpose_a, -// bool transpose_b, -// std::vector& copies, -// Shape batch_shape /* = {} */, -// Strides A_batch_stride /* = {} */, -// Strides B_batch_stride /* = {} */) { - -// return - -// using namespace mlx::steel; - -// if (batch_shape.empty()) { -// ///////////////////////////////////////////////////////////////////////////// -// // Check and collapse batch dimensions -// auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); - -// batch_shape = batch_shape_; -// A_batch_stride = A_bstride_; -// B_batch_stride = B_bstride_; -// // Collapse batches into M if needed -// if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && -// a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && -// B_batch_stride.back() == 0) { -// M *= batch_shape.back(); -// batch_size_out = 1; - -// A_batch_stride = {0}; -// 
B_batch_stride = {0}; -// batch_shape = {1}; -// } -// } - -// size_t matrix_stride_out = size_t(M) * N; - -// ///////////////////////////////////////////////////////////////////////////// -// // Split K specialization - -// int _tm = M / 16; -// int _tn = N / 16; -// int _tk = K / 16; - -// if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { -// return steel_gemm_splitk( -// /* const Stream& s = */ s, -// /* metal::Device& d = */ d, -// /* const array& a = */ a, -// /* const array& b = */ b, -// /* array& out = */ out, -// /* int M = */ M, -// /* int N = */ N, -// /* int K = */ K, -// /* int batch_size_out = */ batch_size_out, -// /* int lda = */ lda, -// /* int ldb = */ ldb, -// /* bool transpose_a = */ transpose_a, -// /* bool transpose_b = */ transpose_b, -// /* std::vector& copies = */ copies); -// } - -// ///////////////////////////////////////////////////////////////////////////// -// // Regular kernel dispatch -// auto batch_strides = A_batch_stride; -// batch_strides.insert( -// batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - -// return steel_matmul_regular( -// /* const Stream& s = */ s, -// /* metal::Device& d = */ d, -// /* const array& a = */ a, -// /* const array& b = */ b, -// /* array& out = */ out, -// /* int M = */ M, -// /* int N = */ N, -// /* int K = */ K, -// /* int batch_size_out = */ batch_size_out, -// /* int lda = */ lda, -// /* int ldb = */ ldb, -// /* int ldd = */ N, -// /* bool transpose_a = */ transpose_a, -// /* bool transpose_b = */ transpose_b, -// /* std::vector& copies = */ copies, -// /* Shape batch_shape = */ std::move(batch_shape), -// /* Strides batch_strides = */ std::move(batch_strides), -// /* int64_t A_batch_stride = */ A_batch_stride.back(), -// /* int64_t B_batch_stride = */ B_batch_stride.back(), -// /* int64_t matrix_stride_out = */ matrix_stride_out); -// } - template void gemv_axbpy( const Stream& s, @@ -1108,46 +1008,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* float beta = */ beta_); } - using namespace mlx::steel; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - return steel_gemm_splitk_axpby( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& a = */ a, - /* const array& b = */ b, - /* const array& c = */ c, - /* array& out = */ out, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* int batch_size_out = */ batch_size_out, - /* int lda = */ lda, - /* int ldb = */ ldb, - /* bool transpose_a = */ transpose_a, - /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies, - /* float alpha = */ alpha_, - /* float beta = */ beta_); - } - ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - Strides batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - batch_strides.insert( - batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - - return steel_matmul_regular_axpby( + return steel_matmul_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, @@ -1160,16 +1024,13 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, - /* int ldd = */ ldd, /* bool transpose_a = */ transpose_a, /* bool 
transpose_b = */ transpose_b, /* std::vector& copies = */ copies, - /* Shape batch_shape = */ std::move(batch_shape), - /* Strides batch_strides = */ std::move(batch_strides), - /* int64_t A_batch_stride = */ A_batch_stride.back(), - /* int64_t B_batch_stride = */ B_batch_stride.back(), - /* int64_t matrix_stride_out = */ int64_t(M) * ldd, - /* int64_t C_batch_stride = */ C_batch_stride.back(), + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } From df3bd6b52d2116d8b599ae84d493bb8d9e2d027d Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:35:52 -0700 Subject: [PATCH 10/13] Comments and format --- mlx/backend/metal/matmul.cpp | 46 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index bc4cee56f..c8c933223 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -201,10 +201,15 @@ void steel_matmul_regular_axpby( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; + + // clang-format off + kname << "steel_gemm_fused_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); @@ -237,18 +242,18 @@ void steel_matmul_regular_axpby( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -722,9 +727,12 @@ void gemv_axbpy( const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; + // clang-format off + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn + << "_tm" << tm << "_tn" << tn + << "_nc" << !contiguous_kernel + << "_axpby" << do_axpby; // clang-format on // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); From ff21bb0347fb9f5587408e1109ab1d3f897979f8 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:43:34 -0700 Subject: [PATCH 11/13] Some cleanup --- .../steel/gemm/kernels/steel_gemm_fused.h | 4 +- mlx/backend/metal/matmul.cpp | 69 ++++++++++--------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index
add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index c8c933223..2a7a8dd94 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,6 +164,10 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + template void steel_matmul_regular_axpby( const Stream& s, @@ -296,8 +300,10 @@ void steel_matmul_regular_axpby( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } if (use_out_source) { int ldc = c.strides()[c.ndim() - 2]; @@ -320,6 +326,10 @@ void steel_matmul_regular_axpby( d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + template void steel_gemm_splitk_axpby( const Stream& s, @@ -466,38 +476,9 @@ void steel_gemm_splitk_axpby( d.add_temporaries(std::move(copies), s.index); } -inline void steel_gemm_splitk( - const Stream& s, - metal::Device& d, - const array& a, - const array& b, - array& out, - int M, - int N, - int K, - int batch_size_out, - int lda, - int ldb, - bool transpose_a, - bool transpose_b, - std::vector& copies) { - return steel_gemm_splitk_axpby( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& a = */ a, - /* const array& b = */ b, - /* const array& c = */ b, - /* array& out = */ out, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* int batch_size_out = */ batch_size_out, - /* int lda = */ lda, - /* int ldb = */ ldb, - /* bool transpose_a = */ transpose_a, - /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies); -} +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// template void steel_matmul_axpby( @@ -637,6 +618,10 @@ void steel_matmul_axpby( /* float beta = */ beta); } +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + template void gemv_axbpy( const Stream& s, @@ -812,6 +797,10 @@ inline void gemv( /* Strides B_batch_stride = */ B_batch_stride); } 
+/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -913,6 +902,10 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { @@ -1043,6 +1036,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1431,6 +1428,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_, From 629e63a3122af36c82b1f0e87a7e76eac9aa9a4c Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:56:01 -0700 Subject: [PATCH 12/13] Add architecture gen to device --- mlx/backend/metal/device.cpp | 3 +++ mlx/backend/metal/device.h | 5 +++++ mlx/backend/metal/matmul.cpp | 2 -- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 425274361..88835eb75 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -297,6 +297,9 @@ Device::Device() { device_ = load_device(); default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5bfcc6649..f87a8c48b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -177,6 +177,10 @@ class Device { return arch_; } + int get_architecture_gen() const { + return arch_gen_; + } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -268,6 +272,7 @@ class Device { library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 2a7a8dd94..be7f3e2f8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -503,8 +503,6 @@ void steel_matmul_axpby( Strides C_batch_stride /* = {} */, float alpha /* = 1.0f */, float beta /* = 0.0f */) { - using namespace mlx::steel; - if (batch_shape.empty()) { ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions From 
d5e3fe7f86e89b77a471edbb1a58189420291f99 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:58:15 -0700 Subject: [PATCH 13/13] Update no copy condition in normalization to account for axis size 1 --- mlx/backend/metal/normalization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index d570bf3c0..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -26,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -227,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) {
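Note on PATCH 13/13: the relaxed condition also skips the copy when the second-to-last axis has size 1, since the kernel never steps along a size-1 axis and that axis's stride is therefore irrelevant. Below is a minimal standalone sketch of the updated predicate; the Arr struct and can_skip_copy helper are hypothetical stand-ins for the relevant parts of mlx::core::array (shape, strides, contiguity flag), not MLX API.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the pieces of mlx::core::array the check reads.
struct Arr {
  std::vector<int> shape;       // dimension sizes
  std::vector<int64_t> strides; // element strides
  bool contiguous;              // mirrors x.flags().contiguous
};

// Mirrors the updated no-copy condition in RMSNorm/LayerNorm::eval_gpu:
// the last axis must be packed (stride 1), and the second-to-last axis must
// be broadcast (stride 0), row-packed (stride == row length), or of size 1.
bool can_skip_copy(const Arr& x) {
  int ndim = static_cast<int>(x.shape.size());
  bool no_copy = x.contiguous && x.strides[ndim - 1] == 1;
  if (no_copy && ndim > 1) {
    auto s = x.strides[ndim - 2];
    no_copy &= (s == 0 || s == x.shape.back() || x.shape[ndim - 2] == 1);
  }
  return no_copy;
}

// Example under this sketch: Arr x{{4, 1, 8}, {64, 17, 1}, true} now passes,
// because the size-1 middle axis makes its stride of 17 irrelevant.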
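Also, regarding PATCH 12/13: the generation number is read from the two characters just before the trailing class letter of the architecture name reported by MTLDevice::architecture(). A small sketch of that parsing as a free function; the example string "applegpu_g15g" is an assumed name format used only for illustration and is not spelled out in the patch.

#include <string>

// Reads the two digits preceding the trailing class letter, e.g. "...g15g" -> 15.
int parse_arch_gen(const std::string& arch) {
  int tens = arch[arch.size() - 3] - '0';
  int ones = arch[arch.size() - 2] - '0';
  return tens * 10 + ones;
}

// parse_arch_gen("applegpu_g15g") == 15   (assumed name format)
// parse_arch_gen("applegpu_g13p") == 13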