From fddb6933e1cdcb268467fc5d02be6b471bb232b9 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 13 Jun 2025 10:44:56 -0700 Subject: [PATCH] Collection of refactors (#2274) * Refactor gemv into a function * Refactor splitk step 1 * Refactor split k axpby * Rearrange steel_gemm_regular * Redirect steel_gemm_regular * Add axpby routing to steel_matmul_regular * Refactor AddMM step 1 * Redirect steel_gemm * Update addmm * Comments and format * Some cleanup * Add architecture gen to device * Update no copy condition in normalization to account for axis size 1 --- mlx/backend/metal/conv.cpp | 40 +- mlx/backend/metal/device.cpp | 3 + mlx/backend/metal/device.h | 5 + .../steel/gemm/kernels/steel_gemm_fused.h | 4 +- mlx/backend/metal/matmul.cpp | 1202 ++++++++--------- mlx/backend/metal/matmul.h | 104 +- mlx/backend/metal/normalization.cpp | 4 +- 7 files changed, 720 insertions(+), 642 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 697afa6a1..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); + /* const Stream& s = */ s, + /* Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 425274361..88835eb75 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -297,6 +297,9 @@ Device::Device() { device_ = load_device(); default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5bfcc6649..f87a8c48b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -177,6 +177,10 @@ class Device { return arch_; } + int get_architecture_gen() const { + return arch_gen_; + } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -268,6 +272,7 @@ class Device { library_kernels_; const 
MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index ed96d37ea..be7f3e2f8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,11 +164,17 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } -void steel_matmul_regular( +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -179,12 +185,15 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, - std::vector& copies) { + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel @@ -196,16 +205,21 @@ void steel_matmul_regular( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; + + // clang-format off + kname << "steel_gemm_fused_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -232,18 +246,18 @@ void steel_matmul_regular( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -286,8 +300,25 @@ void steel_matmul_regular( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -295,7 +326,437 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_matmul( +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_gemm_splitk_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int 
K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape /* = {} */, + Strides A_batch_stride /* = {} */, + Strides B_batch_stride /* = {} */, + Strides C_batch_stride /* = {} */, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + if (batch_shape.empty()) { + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + if constexpr (CHECK_AB) { + auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = + collapse_batches(a, b, c); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + C_batch_stride = C_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + } else { + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + auto batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + if (CHECK_AB && !C_batch_stride.empty()) { + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); + } + + int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); + int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); + int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back(); + + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride_, + /* int64_t B_batch_stride = */ B_batch_stride_, + /* int64_t matrix_stride_out = */ int64_t(M) * N, + /* int64_t C_batch_stride = */ C_batch_stride_, + /* float alpha = */ alpha, + /* float beta = */ beta); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void gemv_axbpy( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? ldb : lda; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv_" << type_to_name(out); + } + + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + // clang-format off + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn + << "_tm" << tm << "_tn" << tn + << "_nc" << !contiguous_kernel + << "_axpby" << do_axpby; // clang-format on + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + + if (do_axpby) { + compute_encoder.set_input_array(c, 2); + + compute_encoder.set_bytes(alpha, 7); + compute_encoder.set_bytes(beta, 8); + + compute_encoder.set_vector_bytes(C_batch_stride, 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder.set_bytes(bias_stride, 14); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(copies), s.index); +} + +inline void gemv( const Stream& s, metal::Device& d, const array& a, @@ -310,166 +771,34 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - Shape batch_shape /* = {} */, - Strides A_batch_stride /* = {} */, - Strides B_batch_stride /* = {} */) { - using namespace mlx::steel; - - if (batch_shape.empty()) { - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); - - batch_shape = batch_shape_; - A_batch_stride = A_bstride_; - B_batch_stride = B_bstride_; - // Collapse batches into M if needed - if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && - B_batch_stride.back() == 0) { - M *= batch_shape.back(); - batch_size_out = 1; - - A_batch_stride = {0}; - B_batch_stride = {0}; - batch_shape = {1}; - } - } - - size_t matrix_stride_out = size_t(M) * N; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 
't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; - } - - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch - auto batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } 
+/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -528,102 +857,26 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 
1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby0"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } + ///////////////////////////////////////////////////////////////////////////// // Gemm specialization @@ -641,12 +894,16 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, - /* std::vector& = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { @@ -726,346 +983,61 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? 
A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby1"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(alpha_, 7); - compute_encoder.set_bytes(beta_, 8); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.set_vector_bytes(C_batch_stride, 13); - - int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder.set_bytes(bias_stride, 14); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; - } - - using namespace mlx::steel; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - - Strides batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - batch_strides.insert( - batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const 
array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides B_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1454,6 +1426,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_, diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..218664b1f 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -6,7 +6,34 @@ namespace mlx::core { -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, @@ -21,14 +48,61 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} -void steel_matmul( +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 
1.0f, + float beta = 0.0f); + +inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -45,6 +119,26 @@ void steel_matmul( std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, - Strides B_batch_stride = {}); + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index d570bf3c0..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -26,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -227,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) {
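
Note on the axpby routing introduced above (an illustration only, not part of the patch): Matmul and AddMM now funnel into the same templated dispatchers, with the CHECK_AB template flag deciding whether the C operand and the alpha/beta scalars are considered at all. The sketch below is a standalone C++ program that reuses the exact flag expressions from steel_matmul_regular_axpby to show how the two fused-kernel function constants fall out of alpha and beta for each call site; the sample alpha/beta values are made up for demonstration.

#include <cstdio>

// Same expressions as in steel_matmul_regular_axpby: CHECK_AB gates the
// out-source path entirely, and do_axpby is only needed when alpha/beta
// actually rescale the accumulator or the C operand.
template <bool CHECK_AB>
void routing_flags(float alpha, float beta) {
  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
  std::printf("CHECK_AB=%d alpha=%g beta=%g -> use_out_source=%d do_axpby=%d\n",
              CHECK_AB, alpha, beta, use_out_source, do_axpby);
}

int main() {
  routing_flags<false>(1.0f, 0.0f); // Matmul via steel_matmul_regular: both flags off
  routing_flags<true>(1.0f, 1.0f);  // AddMM with alpha = beta = 1: reads C, no scaling
  routing_flags<true>(2.0f, 0.5f);  // AddMM with scaling: both flags on
  return 0;
}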