From c2f1c2a33826990b8f7fd176ab9709fd77e138b8 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 07:47:56 -0700 Subject: [PATCH] Refactor split k axpby --- mlx/backend/metal/matmul.cpp | 184 +++++++++++++++-------------------- 1 file changed, 80 insertions(+), 104 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 4b3d82d24..b5deaf0b3 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -295,11 +295,13 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_gemm_splitk( +template +void steel_gemm_splitk_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -309,7 +311,9 @@ void steel_gemm_splitk( int ldb, bool transpose_a, bool transpose_b, - std::vector& copies) { + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { using namespace mlx::steel; int _tm = M / 16; @@ -393,11 +397,21 @@ void steel_gemm_splitk( // Do accum kernel { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); - auto kernel = - get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false); + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel @@ -407,6 +421,17 @@ void steel_gemm_splitk( compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(N, 4); + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); auto group_dims = get_block_dims(N, M, 1); @@ -416,6 +441,39 @@ void steel_gemm_splitk( d.add_temporaries(std::move(copies), s.index); } +inline void steel_gemm_splitk( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies); +} + void steel_matmul( const Stream& s, metal::Device& d, @@ -899,106 +957,24 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } /////////////////////////////////////////////////////////////////////////////