From 358e1fd6abd5d3bd0066a51aa211a198a7360967 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Wed, 15 May 2024 10:30:41 -0700
Subject: [PATCH] Fused GEMM (#1123)

* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
---
 mlx/backend/metal/device.cpp                  |   3 +-
 .../steel/gemm/kernels/steel_gemm.metal       | 136 -----
 .../steel/gemm/kernels/steel_gemm_addmm.metal | 340 -------------
 .../steel/gemm/kernels/steel_gemm_fused.metal | 468 ++++++++++++++++++
 .../gemm/kernels/steel_gemm_gather.metal      | 168 -------
 mlx/backend/metal/kernels/steel/gemm/mma.h    |  67 +++
 .../metal/kernels/steel/gemm/transforms.h     |   8 +
 mlx/backend/metal/matmul.cpp                  | 209 ++++++--
 8 files changed, 709 insertions(+), 690 deletions(-)
 delete mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal
 delete mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal
 create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
 delete mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index f6db67c68..03974db3e 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -357,7 +357,6 @@ MTL::Function* Device::get_function_(
   }
 
   mtl_func_consts->release();
-  desc->release();
 
   return mtl_function;
 }
@@ -526,11 +525,13 @@ MTL::ComputePipelineState* Device::get_kernel(
   // Compile kernel to compute pipeline
   auto mtl_linked_funcs = get_linked_functions_(linked_functions);
   auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
+  mtl_function->release();
   mtl_linked_funcs->release();
 
   // Add kernel to cache
   kernel_map_.insert({kname, kernel});
+
   return kernel;
 }
 
diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal
deleted file mode 100644
index b7445f2b1..000000000
--- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright © 2024 Apple Inc.
- -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - } - - D += params->batch_stride_d * tid.z; - - gemm_kernel::run( - A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid); -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_gemm( \ - tname, \ - trans_a, \ - trans_b, \ - iname, \ - itype, \ - oname, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - aname, \ - mn_aligned, \ - kname, \ - k_aligned) \ - template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \ - "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \ - "_K_" #kname)]] [[kernel]] void \ - gemm( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device itype* D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -// clang-format off -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, 
wn, naligned, false, naligned, false) // clang-format on - -// clang-format off -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on - -// clang-format off -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on - -// clang-format off -instantiate_gemm_shapes_helper(float16, half, float16, half); -instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); - -instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal deleted file mode 100644 index 848ff1b81..000000000 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" - -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned, - typename AccumType = float, - typename Epilogue = TransformAdd> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device T* C [[buffer(2)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant GEMMAddMMParams* addmm_params [[buffer(5)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned, - AccumType, - Epilogue>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + 
params->batch_ndim; - const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; - - ulong3 batch_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - A_bstrides, - B_bstrides, - C_bstrides, - params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - C += batch_offsets.z; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - C += addmm_params->batch_stride_c * tid.z; - } - - D += params->batch_stride_d * tid.z; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - // Store results to device memory - mma_op.store_result( - D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; - - if (tgp_bm == BM && tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - mma_op.store_result( - D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); - return; - - } else if (tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, - params->ldd, - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - - } else if (tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, - params->ldd, - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, - params->ldd, - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_gemm( \ - tname, \ - trans_a, \ - trans_b, \ - iname, \ - itype, \ - oname, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - aname, \ - mn_aligned, \ - kname, \ - k_aligned, \ - ep_name, \ - epilogue) \ - template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \ - "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \ - "_K_" #kname "_" #ep_name)]] [[kernel]] void \ - addmm< \ - itype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - trans_a, \ - trans_b, \ - mn_aligned, \ - k_aligned, \ - float, \ - epilogue>( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - const device itype* C [[buffer(2)]], \ - device itype* D [[buffer(3)]], \ - const constant GEMMParams* gemm_params [[buffer(4)]], \ - const constant GEMMAddMMParams* params [[buffer(5)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -// clang-format off -#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, 
k_aligned, add, TransformAdd<otype>) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby<otype>) // clang-format on
-
-// clang-format off
-#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
-  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
-  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
-
-// clang-format off
-#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
-
-// clang-format off
-#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
-
-// clang-format off
-instantiate_gemm_shapes_helper(float16, half, float16, half);
-instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
-
-instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
new file mode 100644
index 000000000..b4304a551
--- /dev/null
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
@@ -0,0 +1,468 @@
+// Copyright © 2024 Apple Inc.
+ +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + + // Handle gather + if (do_gather) { + // Read indices + uint32_t indx_A, indx_B, indx_C; + + if (has_batch) { + const constant size_t* indx_A_bstrides = batch_strides; + const constant size_t* indx_B_bstrides = + batch_strides + params->batch_ndim; + + ulong2 indx_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + indx_A_bstrides, + indx_B_bstrides, + params->batch_ndim); + indx_A = lhs_indices[indx_offsets.x]; + indx_B = rhs_indices[indx_offsets.y]; + + if (use_out_source) { + const constant size_t* indx_C_bstrides = + indx_B_bstrides + params->batch_ndim; + auto indx_offset_C = elem_to_loc( + tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); + indx_C = 
C_indices[indx_offset_C]; + } + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + + if (use_out_source) { + indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; + } + } + + // Translate indices to offsets + int batch_ndim_A = operand_batch_ndim.x; + const constant int* batch_shape_A = operand_shape; + const constant size_t* batch_strides_A = operand_strides; + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + + int batch_ndim_B = operand_batch_ndim.y; + const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; + const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + + if (use_out_source) { + int batch_ndim_C = operand_batch_ndim.z; + const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; + const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + } + + } + + // Handle regular batch + else { + if (has_batch) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? 
short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + 
Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +// clang-format off +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ + [[kernel]] void gemm( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + const device itype *C [[buffer(2), function_constant(use_out_source)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \ + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \ + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \ + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \ + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \ + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); // clang-format on + +// clang-format off +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on + +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file diff --git 
a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal deleted file mode 100644 index d93417e81..000000000 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10)]], - const constant uint32_t* rhs_indices [[buffer(11)]], - const constant int* batch_shape_A [[buffer(12)]], - const constant size_t* batch_strides_A [[buffer(13)]], - const constant int* batch_shape_B [[buffer(14)]], - const constant size_t* batch_strides_B [[buffer(15)]], - const constant int2& operand_batch_ndim [[buffer(16)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - uint32_t indx_A; - uint32_t indx_B; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - indx_A = lhs_indices[batch_offsets.x]; - indx_B = rhs_indices[batch_offsets.y]; - - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - } - - int batch_ndim_A = operand_batch_ndim.x; - int batch_ndim_B = operand_batch_ndim.y; - - if (batch_ndim_A > 1) { - A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); - } else { - A += indx_A * batch_strides_A[0]; - } - - if (batch_ndim_B > 1) { - B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); - } else { - B += indx_B * batch_strides_B[0]; - } - - D += params->batch_stride_d * tid.z; - - gemm_kernel::run( - A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid); -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_gemm( \ - tname, \ - trans_a, \ - trans_b, \ - iname, \ - itype, \ - oname, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - aname, \ - mn_aligned, \ - kname, \ - k_aligned) \ - template [[host_name("steel_block_sparse_gemm_" 
#tname "_" #iname "_" #oname \ - "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ - "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ - bs_gemm( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device itype* D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - const constant uint32_t* lhs_indices [[buffer(10)]], \ - const constant uint32_t* rhs_indices [[buffer(11)]], \ - const constant int* batch_shape_A [[buffer(12)]], \ - const constant size_t* batch_strides_A [[buffer(13)]], \ - const constant int* batch_shape_B [[buffer(14)]], \ - const constant size_t* batch_strides_B [[buffer(15)]], \ - const constant int2& operand_batch_ndim [[buffer(16)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -// clang-format off -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on - -// clang-format off -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on - -// clang-format off -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on - -// clang-format off -instantiate_gemm_shapes_helper(float16, half, float16, half); -instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); - -instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index 1190cff47..0fab5b0b2 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -198,6 +198,73 @@ struct BlockMMA { } } + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) 
* ldc + (tn + sn) * fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], C[offset_c]); + accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Read C + U c_elems[2] = {0}; + + if ((j * TN_stride + 1) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + c_elems[1] = C[offset_c + fdc]; + } else if ((j * TN_stride) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + } + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], c_elems[0]); + accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + } + } + } + /* Store results from simdgroup_matrix results into device memory */ METAL_FUNC void store_result( device U* D, diff --git a/mlx/backend/metal/kernels/steel/gemm/transforms.h b/mlx/backend/metal/kernels/steel/gemm/transforms.h index 100f34925..c0624d21b 100644 --- a/mlx/backend/metal/kernels/steel/gemm/transforms.h +++ b/mlx/backend/metal/kernels/steel/gemm/transforms.h @@ -26,6 +26,10 @@ template struct TransformAdd { TransformAdd(const float, const float) {} + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + static METAL_FUNC OutT apply(InT x, OutT c) { return static_cast(x) + c; } @@ -39,6 +43,10 @@ struct TransformAxpby { TransformAxpby(const float alpha_, const float beta_) : alpha(alpha_), beta(beta_) {} + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + METAL_FUNC OutT apply(InT x, OutT c) const { return static_cast(x * alpha + (beta * c)); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f82d315ba..255391238 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -298,16 +298,46 @@ void steel_matmul_conv_groups( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_" << (transpose_a ? 't' : 'n') + kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? 
"t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn; + + std::string base_name = kname.str(); + + const bool has_batch = false; + const bool use_out_source = false; + const bool do_axpby = false; + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + const bool do_gather = false; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + {&do_gather, MTL::DataType::DataTypeBool, 300}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n') + << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + compute_encoder->setComputePipelineState(kernel); // Use problem size to determine threadblock swizzle @@ -351,10 +381,8 @@ void steel_matmul_conv_groups( compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); - compute_encoder->setBytes( - batch_shape.data(), sizeof(int) * batch_shape.size(), 6); - compute_encoder->setBytes( - batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); + set_vector_bytes(compute_encoder, batch_shape, 6); + set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.dispatchThreadgroups(grid_dims, group_dims); @@ -521,16 +549,46 @@ void steel_matmul( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_" << (transpose_a ? 't' : 'n') + kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn; + + std::string base_name = kname.str(); + + const bool has_batch = (batch_shape.size() > 1); + const bool use_out_source = false; + const bool do_axpby = false; + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + const bool do_gather = false; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + {&do_gather, MTL::DataType::DataTypeBool, 300}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n') + << "_do_gather_" << (do_gather ? 
't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + compute_encoder->setComputePipelineState(kernel); // Use problem size to determine threadblock swizzle @@ -576,10 +634,8 @@ void steel_matmul( compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); - compute_encoder->setBytes( - batch_shape.data(), sizeof(int) * batch_shape.size(), 6); - compute_encoder->setBytes( - batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); + set_vector_bytes(compute_encoder, batch_shape, 6); + set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.dispatchThreadgroups(grid_dims, group_dims); @@ -957,13 +1013,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&beta_, sizeof(float), 8); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10); - compute_encoder->setBytes( - batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11); - compute_encoder->setBytes( - batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12); - compute_encoder->setBytes( - C_batch_stride.data(), batch_ndim * sizeof(size_t), 13); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, batch_strides_vec, 11); + set_vector_bytes(compute_encoder, batch_strides_mat, 12); + set_vector_bytes(compute_encoder, C_batch_stride, 13); int bias_stride = c.strides()[c.ndim() - 1]; compute_encoder->setBytes(&bias_stride, sizeof(int), 14); @@ -1089,18 +1142,48 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { } } + // Prepare kernel name std::ostringstream kname; - kname << "steel_addmm_" << (transpose_a ? 't' : 'n') + kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned" - << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby"); + << "_wm" << wm << "_wn" << wn; + + std::string base_name = kname.str(); + + const bool has_batch = (batch_shape.size() > 1); + const bool use_out_source = true; + const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + const bool do_gather = false; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + {&do_gather, MTL::DataType::DataTypeBool, 300}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n') + << "_do_gather_" << (do_gather ? 
't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + compute_encoder->setComputePipelineState(kernel); int tn = (N + bn - 1) / bn; @@ -1155,10 +1238,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4); compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5); - compute_encoder->setBytes( - batch_shape.data(), sizeof(int) * batch_shape.size(), 6); - compute_encoder->setBytes( - batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); + set_vector_bytes(compute_encoder, batch_shape, 6); + set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.dispatchThreadgroups(grid_dims, group_dims); @@ -1593,16 +1674,46 @@ void BlockSparseMM::eval_gpu(const std::vector& inputs, array& out) { // Prepare kernel name std::ostringstream kname; - kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n') + kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn; + + std::string base_name = kname.str(); + + const bool has_batch = batch_ndim > 1; + const bool use_out_source = false; + const bool do_axpby = false; + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + const bool do_gather = true; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + {&do_gather, MTL::DataType::DataTypeBool, 300}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n') + << "_do_gather_" << (do_gather ? 
't' : 'n'); // clang-format on
+
+  std::string hash_name = kname.str();
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+
   compute_encoder->setComputePipelineState(kernel);
 
   // Use problem size to determine threadblock swizzle
@@ -1650,11 +1761,19 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(lhs_indices, 10);
   compute_encoder.set_input_array(rhs_indices, 11);
 
-  set_vector_bytes(compute_encoder, batch_shape_A, 12);
-  set_vector_bytes(compute_encoder, batch_strides_A, 13);
-  set_vector_bytes(compute_encoder, batch_shape_B, 14);
-  set_vector_bytes(compute_encoder, batch_strides_B, 15);
-  set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
+  std::vector<int> operand_shape = batch_shape_A;
+  operand_shape.insert(
+      operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
+
+  std::vector<size_t> operand_strides = batch_strides_A;
+  operand_strides.insert(
+      operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
+
+  operand_batch_ndim.push_back(0);
+
+  set_vector_bytes(compute_encoder, operand_shape, 13);
+  set_vector_bytes(compute_encoder, operand_strides, 14);
+  set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
 
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
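
Note on the dispatch pattern above: instead of baking the alignment and epilogue variants into separate template instantiations (as the deleted steel_gemm, steel_gemm_addmm, and steel_gemm_gather kernels did), the fused kernel is specialized at pipeline-creation time through Metal function constants (ids 10, 100, 110, 200, 201, 202, 300). Below is a minimal host-side sketch of that mechanism using metal-cpp; the function name, the `device`/`library` handles, and the helper itself are illustrative stand-ins, not MLX's actual Device::get_kernel plumbing.

// Sketch only: specializing one fused GEMM kernel via function constants.
// Assumes metal-cpp; `device` and `library` come from elsewhere in a real
// program, and the kernel name here is one example instantiation.
#include <Metal/Metal.hpp>

MTL::ComputePipelineState* make_fused_gemm_pipeline(
    MTL::Device* device,
    MTL::Library* library,
    bool has_batch,
    bool use_out_source,
    bool do_axpby,
    bool align_M,
    bool align_N,
    bool align_K,
    bool do_gather) {
  // Indices must match the [[function_constant(N)]] ids in
  // steel_gemm_fused.metal.
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  consts->setConstantValue(&has_batch, MTL::DataType::DataTypeBool, 10);
  consts->setConstantValue(&use_out_source, MTL::DataType::DataTypeBool, 100);
  consts->setConstantValue(&do_axpby, MTL::DataType::DataTypeBool, 110);
  consts->setConstantValue(&align_M, MTL::DataType::DataTypeBool, 200);
  consts->setConstantValue(&align_N, MTL::DataType::DataTypeBool, 201);
  consts->setConstantValue(&align_K, MTL::DataType::DataTypeBool, 202);
  consts->setConstantValue(&do_gather, MTL::DataType::DataTypeBool, 300);

  // The base name selects the template instantiation (transpose mode, dtype,
  // tile shape); the constants specialize everything else, so dead branches
  // are compiled away when the pipeline is built.
  NS::Error* error = nullptr;
  auto* function = library->newFunction(
      NS::String::string(
          "steel_gemm_fused_nn_float32_float32_bm64_bn64_bk16_wm2_wn2",
          NS::UTF8StringEncoding),
      consts,
      &error);
  consts->release();
  if (function == nullptr) {
    return nullptr; // compilation failed; inspect `error` in real code
  }

  auto* pipeline = device->newComputePipelineState(function, &error);
  function->release();
  return pipeline;
}

This is also why the host code in matmul.cpp now builds two names: `base_name` keys the Metal function in the library, while `hash_name` (the base name plus the flag suffix) keys the compiled pipeline in the kernel cache, so each flag combination is compiled once and reused.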
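
The tid swizzle at the top of the fused kernel is easy to verify on the CPU: with swizzle_log = L, the low L bits of tid.x pick a row within a 2^L-row band and the remaining bits pick the column, so consecutively scheduled threadgroups cover a compact band of output tiles and reuse the A and B blocks they share (out-of-range threadgroups exit early via the tiles_m/tiles_n check). A self-contained sketch of the same index math, with illustrative grid sizes:

// Sketch of the threadgroup swizzle from steel_gemm_fused.metal; the grid
// sizes are illustrative, the index math is copied from the kernel.
#include <cstdio>

int main() {
  const int swizzle_log = 1; // band height = 1 << 1 = 2 tile rows
  const int grid_x = 8;      // raw threadgroup x-extent
  const int grid_y = 2;      // raw threadgroup y-extent

  for (int ty = 0; ty < grid_y; ty++) {
    for (int tx = 0; tx < grid_x; tx++) {
      // Same remap as the kernel's tid_y / tid_x computation.
      int tile_row = (ty << swizzle_log) + (tx & ((1 << swizzle_log) - 1));
      int tile_col = tx >> swizzle_log;
      std::printf("tg(%d,%d) -> tile(%d,%d)\n", tx, ty, tile_row, tile_col);
    }
  }
  // With swizzle_log = 1, tx = 0,1,2,3,... maps to tiles (0,0), (1,0),
  // (0,1), (1,1), ...: column-major order within each 2-row band.
  return 0;
}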