From d192587cdf4e3de4c82cbdef821d4538b7ee39c2 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 07:29:36 -0700 Subject: [PATCH] Refactor splitk step 1 --- mlx/backend/metal/matmul.cpp | 283 ++++++++++++++++++++--------------- 1 file changed, 163 insertions(+), 120 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 8f4fa8e61..4b3d82d24 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -295,6 +295,127 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } +void steel_gemm_splitk( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + auto kernel = + get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + void steel_matmul( const Stream& s, metal::Device& d, @@ -346,99 +467,21 @@ void steel_matmul( int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return steel_gemm_splitk( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies); } ///////////////////////////////////////////////////////////////////////////// @@ -447,27 +490,27 @@ void steel_matmul( batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + return steel_matmul_regular( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride.back(), + /* int64_t B_batch_stride = */ B_batch_stride.back(), + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* std::vector& copies = */ copies); } template @@ -715,9 +758,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } ///////////////////////////////////////////////////////////////////////////// @@ -738,9 +781,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } void AddMM::eval_gpu(const std::vector& inputs, array& out) {