// Copyright © 2023-2024 Apple Inc. #include #include #include #include #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/matmul.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { std::tuple check_transpose( std::vector& copies, const Stream& s, const array& arr, bool is_vector) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { return std::make_tuple(false, stx, arr); } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); copies.push_back(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } }; inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); d.add_temporary(x_copy, s.index); return x_copy; } else { return x; } } inline std::tuple ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { if (x.flags().row_contiguous) { return std::make_tuple(false, x.strides()[x.ndim() - 2], x); } bool rc = true; for (int i = 0; i < x.ndim() - 3; i++) { rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i]; } if (rc) { auto stx = x.strides()[x.ndim() - 2]; auto sty = x.strides()[x.ndim() - 1]; auto K = x.shape(-2); auto N = x.shape(-1); if (sty == 1 && (N != 1 || stx == N)) { return std::make_tuple(false, stx, x); } if (stx == 1 && (N != 1 || sty == K)) { return std::make_tuple(true, sty, x); } } array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); d.add_temporary(x_copy, s.index); return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); } } // namespace /////////////////////////////////////////////////////////////////////////////// // Steel matmul fallback /////////////////////////////////////////////////////////////////////////////// #define GEMM_TPARAM_MACRO(devc) \ if (devc == 'g' || devc == 'p') { /* Small device */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else if (out.dtype() != float32) { /* half and bfloat */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } else if (devc == 'd') { /* Large device */ \ if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \ if (out.dtype() != float32) { /* half and bfloat */ \ if (2 * std::max(M, N) > K) { /* Reasonable K */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } else if (!transpose_a && transpose_b) { /* nt with large k */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else { /* nn with large K */ \ bm = 32; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } /* float takes default */ \ } else { /* smaller matmul */ \ if (out.dtype() != float32) { /* half and bfloat */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else { /* nn */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } else { /* floats */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 32; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } else { /* nn */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } \ } \ } \ } else { /* Medium device */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 2; \ wn = 2; \ } /////////////////////////////////////////////////////////////////////////////// // Regular steel matmul dispatch /////////////////////////////////////////////////////////////////////////////// template void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) // Prepare kernel name std::ostringstream kname; // clang-format off kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, {&use_out_source, MTL::DataType::DataTypeBool, 100}, {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // clang-format off kname << "_has_batch_" << (has_batch ? 't' : 'n') << "_use_out_source_" << (use_out_source ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, /* const std::string& hash_name = */ hash_name, /* const metal::MTLFCList& func_consts = */ func_consts, /* const array& out = */ out, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; // TODO: Explore device-based tuning for swizzle int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); // Prepare steel matmul params GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldd = */ ldd, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int64_t batch_stride_a = */ A_batch_stride, /* const int64_t batch_stride_b = */ B_batch_stride, /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; tn = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); if (has_batch) { compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); } if (use_out_source) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; GEMMAddMMParams params{ /* const int ldc = */ ldc, /* const int fdc = */ fdc, /* const int64_t batch_stride_c = */ C_batch_stride, /* const float alpha = */ alpha, /* const float beta = */ beta}; compute_encoder.set_input_array(c, 2); compute_encoder.set_bytes(params, 5); } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // Split k steel matmul /////////////////////////////////////////////////////////////////////////////// template void steel_gemm_splitk_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, float alpha = 1.0f, float beta = 0.0f) { using namespace mlx::steel; int _tm = M / 16; int _tn = N / 16; int _tk = K / 16; int bm = M < 40 ? 16 : 32; int bn = N < 40 ? 16 : 32; int bk = 16; int wm = 2, wn = 2; int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); int split_k_partition_stride = M * N; int gemm_k_iterations = (K / bk) / split_k_partitions; int split_k_partition_size = gemm_k_iterations * bk; array C_split({split_k_partitions, M, N}, float32, nullptr, {}); C_split.set_data(allocator::malloc(C_split.nbytes())); copies.push_back(C_split); bool mn_aligned = M % bm == 0 && N % bn == 0; bool k_aligned = K % bk == 0; std::ostringstream kname; // clang-format off kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on // Encode and dispatch gemm kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_splitk_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kname.str(), /* const array& in = */ a, /* const array& out = */ C_split, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn, /* bool mn_aligned = */ mn_aligned, /* bool k_aligned = */ k_aligned); compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; GEMMSpiltKParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldc = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int split_k_partitions = */ split_k_partitions, /* const int split_k_partition_stride = */ split_k_partition_stride, /* const int split_k_partition_size = */ split_k_partition_size, /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(C_split, 2); compute_encoder.set_bytes(params, 3); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Do accum kernel { const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); if (do_axpby) { kernel_name = kernel_name + "_axbpy"; } auto kernel = get_steel_gemm_splitk_accum_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kernel_name, /* const array& in = */ C_split, /* const array& out = */ out, /* bool axbpy = */ do_axpby); compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel compute_encoder.set_input_array(C_split, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(split_k_partitions, 2); compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(N, 4); if (do_axpby) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; compute_encoder.set_input_array(c, 5); compute_encoder.set_bytes(ldc, 6); compute_encoder.set_bytes(fdc, 7); compute_encoder.set_bytes(alpha, 8); compute_encoder.set_bytes(beta, 9); } // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); auto group_dims = get_block_dims(N, M, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // Split matmul routing /////////////////////////////////////////////////////////////////////////////// template void steel_matmul_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape /* = {} */, Strides A_batch_stride /* = {} */, Strides B_batch_stride /* = {} */, Strides C_batch_stride /* = {} */, float alpha /* = 1.0f */, float beta /* = 0.0f */) { if (batch_shape.empty()) { ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions if constexpr (CHECK_AB) { auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = collapse_batches(a, b, c); batch_shape = batch_shape_; A_batch_stride = A_bstride_; B_batch_stride = B_bstride_; C_batch_stride = C_bstride_; // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; C_batch_stride = {0}; batch_shape = {1}; } } else { auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); batch_shape = batch_shape_; A_batch_stride = A_bstride_; B_batch_stride = B_bstride_; // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; batch_shape = {1}; } } } ///////////////////////////////////////////////////////////////////////////// // Split K specialization int _tm = M / 16; int _tn = N / 16; int _tk = K / 16; if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { return steel_gemm_splitk_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* float alpha = */ alpha, /* float beta = */ beta); } ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch auto batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); if (CHECK_AB && !C_batch_stride.empty()) { batch_strides.insert( batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); } int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back(); return steel_matmul_regular_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* int ldd = */ N, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides batch_strides = */ std::move(batch_strides), /* int64_t A_batch_stride = */ A_batch_stride_, /* int64_t B_batch_stride = */ B_batch_stride_, /* int64_t matrix_stride_out = */ int64_t(M) * N, /* int64_t C_batch_stride = */ C_batch_stride_, /* float alpha = */ alpha, /* float beta = */ beta); } /////////////////////////////////////////////////////////////////////////////// // GEMV dispatch /////////////////////////////////////////////////////////////////////////////// template void gemv_axbpy( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}, Strides C_batch_stride = {}, float alpha = 1.0f, float beta = 0.0f) { // Collect problem info bool is_b_matrix = N != 1; auto& mat = is_b_matrix ? b : a; auto& vec = is_b_matrix ? a : b; bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; int in_vector_len = K; int out_vector_len = is_b_matrix ? N : M; int mat_cols = transpose_mat ? out_vector_len : in_vector_len; int mat_rows = transpose_mat ? in_vector_len : out_vector_len; int mat_ld = is_b_matrix ? ldb : lda; auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; int stride_mat = batch_strides_mat.back(); int stride_vec = batch_strides_vec.back(); // Determine if inputs have simple batching / broadcasting bool contiguous_kernel = (batch_shape.size() == 1); int batch_ndim = batch_shape.size(); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { if (in_vector_len >= 8192 && out_vector_len >= 2048) { sm = 4; sn = 8; } else { sm = 8; sn = 4; } if (out_vector_len >= 2048) { bn = 16; } else if (out_vector_len >= 512) { bn = 4; } else { bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv_" << type_to_name(out); } const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); // clang-format off kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; // clang-format on // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(batch_shape, 10); compute_encoder.set_vector_bytes(batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_mat, 12); if (do_axpby) { compute_encoder.set_input_array(c, 2); compute_encoder.set_bytes(alpha, 7); compute_encoder.set_bytes(beta, 8); compute_encoder.set_vector_bytes(C_batch_stride, 13); int bias_stride = c.strides()[c.ndim() - 1]; compute_encoder.set_bytes(bias_stride, 14); } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } inline void gemv( const Stream& s, metal::Device& d, const array& a, const array& b, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}) { return gemv_axbpy( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride); } /////////////////////////////////////////////////////////////////////////////// // Matmul implementation /////////////////////////////////////////////////////////////////////////////// void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } auto& s = stream(); auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; // Return 0s if either input is empty if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b); auto batch_size_out = out.size() / (size_t(M) * size_t(N)); // Collapse batches into M if needed if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; batch_shape = {1}; } ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { return gemv( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ a_cols, /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides A_batch_stride = */ std::move(A_batch_stride), /* Strides B_batch_stride = */ std::move(B_batch_stride)); } ///////////////////////////////////////////////////////////////////////////// // Gemm specialization return steel_matmul( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ a_cols, /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides A_batch_stride = */ std::move(A_batch_stride), /* Strides B_batch_stride = */ std::move(B_batch_stride)); } /////////////////////////////////////////////////////////////////////////////// // AddMM implementation /////////////////////////////////////////////////////////////////////////////// void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } // Return 0s if either input is empty if (out.size() == 0) { out.set_data(allocator::malloc(out.nbytes())); return; } // Copy c into out and return if (inputs[0].shape(-1) == 0) { copy_gpu( inputs[2], out, inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General, stream()); return; } out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; auto& c_pre = inputs[2]; ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); array c = c_pre; int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; int lda = a_cols; int ldb = b_cols; int ldd = N; ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] = collapse_batches(a, b, c); int64_t matrix_stride_out = M * static_cast(N); auto batch_size_out = out.size() / (matrix_stride_out); // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; C_batch_stride = {0}; batch_shape = {1}; } ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { return gemv_axbpy( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride, /* Strides C_batch_stride = */ C_batch_stride, /* float alpha = */ alpha_, /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch return steel_matmul_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride, /* Strides B_batch_stride = */ C_batch_stride, /* float alpha = */ alpha_, /* float beta = */ beta_); } /////////////////////////////////////////////////////////////////////////////// // BlockMaskedMM implementation /////////////////////////////////////////////////////////////////////////////// void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } auto& s = stream(); auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; // Return 0s if either input is empty if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); int lda = a_cols; int ldb = b_cols; ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions bool has_op_mask = inputs.size() > 3; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; // Prepare kernel name std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask"; auto get_batch_dims = [](const auto& v) { return decltype(v){v.begin(), v.end() - 2}; }; Shape batch_shape{1}; Strides A_batch_stride{0}; Strides B_batch_stride{0}; Strides outmask_bstride{0}; Strides Amask_bstride{0}; Strides Bmask_bstride{0}; int64_t A_batch_str = 0; int64_t B_batch_str = 0; Strides batch_strides; if (out.ndim() > 2) { Shape bshape{out.shape().begin(), out.shape().end() - 2}; std::vector bstrides; for (auto& arr : inputs) { bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2); } // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides); batch_shape = bshape; A_batch_str = bstrides[0].back(); B_batch_str = bstrides[1].back(); for (auto& bstr : bstrides) { batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end()); } A_batch_stride = bstrides[0]; B_batch_stride = bstrides[1]; if (has_out_mask) { outmask_bstride = bstrides[2]; } if (has_op_mask) { Amask_bstride = bstrides[has_out_mask + 2]; Bmask_bstride = bstrides[has_out_mask + 3]; } } else { batch_strides = Strides(inputs.size(), 0); } int64_t matrix_stride_out = static_cast(M) * N; size_t batch_size_out = out.size() / (matrix_stride_out); ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { // Collect problem info bool is_b_matrix = N != 1; auto& mat = is_b_matrix ? b : a; auto& vec = is_b_matrix ? a : b; bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; int in_vector_len = K; int out_vector_len = is_b_matrix ? N : M; int mat_cols = transpose_mat ? out_vector_len : in_vector_len; int mat_rows = transpose_mat ? in_vector_len : out_vector_len; int mat_ld = is_b_matrix ? b_cols : a_cols; auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride; auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride; auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2); auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3); // Determine if inputs have simple batching / broadcasting bool contiguous_kernel = (batch_shape.size() == 1); int batch_ndim = batch_shape.size(); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { sm = 8; sn = 4; bm = 1; bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2; tm = block_size_ == 32 ? 4 : 8; tn = 4; // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t"; } else { if (block_size_ == 32) { sm = 4; sn = 8; bm = 2; } else { sm = 2; sn = 16; bm = out_vector_len >= 512 ? 4 : 2; } // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv"; } kname << "_outmask_" << out_mask_nm; kname << "_opmask_" << op_mask_nm; kname << "_" << type_to_name(out); kname << "_bm" << bm << "_bn" << bn; kname << "_sm" << sm << "_sn" << sn; kname << "_tm" << tm << "_tn" << tn; kname << "_nc" << !contiguous_kernel; // Encode and dispatch kernel auto kernel = get_gemv_masked_kernel( d, kname.str(), out, has_out_mask ? std::optional{inputs[2]} : std::nullopt, has_op_mask ? std::optional{inputs.back()} : std::nullopt, transpose_mat, bm, bn, sm, sn, tm, tn, contiguous_kernel); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); // Get mask params std::vector mask_strides; Strides mask_batch_strides; if (has_out_mask) { auto& out_mask = inputs[2]; if (transpose_mat) { mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2)); mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1)); } else { mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2)); mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1)); } mask_batch_strides.insert( mask_batch_strides.end(), outmask_bstride.begin(), outmask_bstride.end()); compute_encoder.set_input_array(out_mask, 20); } if (has_op_mask) { auto& mat_mask = inputs[mat_mask_idx]; if (transpose_mat) { mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1)); mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2)); } else { mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1)); mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2)); } mask_batch_strides.insert( mask_batch_strides.end(), mask_bstrides_mat.begin(), mask_bstrides_mat.end()); compute_encoder.set_input_array(mat_mask, 21); auto& vec_mask = inputs[vec_mask_idx]; if (transpose_mat) { mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2)); mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1)); } else { mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2)); mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1)); } mask_batch_strides.insert( mask_batch_strides.end(), mask_bstrides_vec.begin(), mask_bstrides_vec.end()); compute_encoder.set_input_array(vec_mask, 22); } // Get gemv params compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(batch_shape, 10); compute_encoder.set_vector_bytes(batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_mat, 12); compute_encoder.set_vector_bytes(mask_strides, 23); compute_encoder.set_vector_bytes(mask_batch_strides, 24); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; } ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch // Determine dispatch kernel int bm = block_size_, bn = block_size_, bk = 16; int wm = 2, wn = 2; bool mn_aligned = M % bm == 0 && N % bn == 0; bool k_aligned = K % bk == 0; std::ostringstream kname; kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_" << op_mask_nm << "_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_masked_kernel( d, kname.str(), out, has_out_mask ? std::optional{inputs[2]} : std::nullopt, has_op_mask ? std::optional{inputs.back()} : std::nullopt, transpose_a, transpose_b, bm, bn, bk, wm, wn, mn_aligned, k_aligned); compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; // TODO: Explore device-based tuning for swizzle int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); // Prepare steel matmul params GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int64_t batch_stride_a = */ A_batch_str, /* const int64_t batch_stride_b = */ B_batch_str, /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; tn = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); std::vector mask_strides; if (has_out_mask) { auto& out_mask = inputs[2]; mask_strides.push_back(*(out_mask.strides().end() - 1)); mask_strides.push_back(*(out_mask.strides().end() - 2)); compute_encoder.set_input_array(out_mask, 10); } if (has_op_mask) { auto& lhs_mask = inputs[2 + has_out_mask]; mask_strides.push_back(*(lhs_mask.strides().end() - 1)); mask_strides.push_back(*(lhs_mask.strides().end() - 2)); compute_encoder.set_input_array(lhs_mask, 11); auto& rhs_mask = inputs[3 + has_out_mask]; mask_strides.push_back(*(rhs_mask.strides().end() - 1)); mask_strides.push_back(*(rhs_mask.strides().end() - 2)); compute_encoder.set_input_array(rhs_mask, 12); } // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(mask_strides, 13); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // GatherMM implementation /////////////////////////////////////////////////////////////////////////////// void gather_mm_rhs( const array& a_, const array& b_, const array& indices_, array& out, metal::Device& d, const Stream& s) { array indices = ensure_row_contiguous(indices_, d, s); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); // Broadcast a with indices. If we are here that means lhs_indices were not // provided so the lhs_indices are implied to be the shape of a broadcasted // with rhs_indices. We need only broadcast a and copy it as if applying the // lhs_indices. auto broadcast_with_indices = [&d, &s, &indices](const array& x) { if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { return ensure_row_contiguous(x, d, s); } auto x_shape = indices.shape(); x_shape.push_back(x.shape(-2)); x_shape.push_back(x.shape(-1)); array new_x(std::move(x_shape), x.dtype(), nullptr, {}); broadcast(x, new_x); return ensure_row_contiguous(new_x, d, s); }; array a = broadcast_with_indices(a_); // Extract the matmul shapes int K = a.shape(-1); int M = a.size() / K; int N = b.shape(-1); int lda = a.strides()[a.ndim() - 2]; // should be K // Define the dispatch blocks int bm = 16, bn = 64, bk = 16; int wm = 1, wn = 2; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Define the kernel name std::string base_name; base_name.reserve(64); concatenate( base_name, "steel_gather_mm_rhs_n", transpose_b ? 't' : 'n', '_', type_to_name(a), '_', type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n', "_align_K_", align_K ? 't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, func_consts, out, false, transpose_b, bm, bn, bk, wm, wn, true); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size(); steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ 0, /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), /* const int64_t batch_stride_d = */ 0, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ 0}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(indices, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_mv( const array& mat_, const array& vec_, const array& mat_indices_, const array& vec_indices_, array& out, int N, int K, bool is_mv, metal::Device& d, const Stream& s) { // Copy if needed std::vector copies; auto [transpose_mat, mat_cols, mat] = check_transpose(copies, s, mat_, N == 1); auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); d.add_temporaries(std::move(copies), s.index); // If we are doing vector matrix instead of matrix vector we need to flip the // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated // as a one dimensional array. transpose_mat = (!is_mv) ^ transpose_mat; // Define some shapes int in_vector_len = K; int out_vector_len = N; int mat_ld = mat_cols; int batch_size_out = out.size() / N; int batch_ndim = out.ndim() - 2; int batch_ndim_mat = mat.ndim() - 2; int batch_ndim_vec = vec.ndim() - 2; Strides index_strides = vec_indices_.strides(); index_strides.insert( index_strides.end(), mat_indices_.strides().begin(), mat_indices_.strides().end()); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { if (in_vector_len >= 8192 && out_vector_len >= 2048) { sm = 4; sn = 8; } else { sm = 8; sn = 4; } if (out_vector_len >= 2048) { bn = 16; } else if (out_vector_len >= 512) { bn = 4; } else { bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t_gather_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv_gather_" << type_to_name(out); } kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(out.shape(), 10); compute_encoder.set_vector_bytes(index_strides, 11); compute_encoder.set_bytes(batch_ndim_vec, 12); compute_encoder.set_vector_bytes(vec.shape(), 13); compute_encoder.set_vector_bytes(vec.strides(), 14); compute_encoder.set_bytes(batch_ndim_mat, 15); compute_encoder.set_vector_bytes(mat.shape(), 16); compute_encoder.set_vector_bytes(mat.strides(), 17); compute_encoder.set_input_array(vec_indices_, 18); compute_encoder.set_input_array(mat_indices_, 19); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_mm( const array& a_, const array& b_, const array& lhs_indices, const array& rhs_indices, array& out, int M, int N, int K, metal::Device& d, const Stream& s) { // Copy if needed std::vector copies; auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; size_t batch_size_out = out.size() / M / N; int batch_ndim = out.ndim() - 2; int batch_ndim_a = a.ndim() - 2; int batch_ndim_b = b.ndim() - 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) const bool has_batch = batch_ndim > 1; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Define the kernel name std::string base_name; base_name.reserve(128); concatenate( base_name, "steel_gather_mm_", transpose_a ? 't' : 'n', transpose_b ? 't' : 'n', "_", type_to_name(a), "_", type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_has_batch_", has_batch ? 't' : 'n', "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n', "_align_K_", align_K ? 't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, func_consts, out, transpose_a, transpose_b, bm, bn, bk, wm, wn, false); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ static_cast(lda), /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ (batch_ndim > 0) ? lhs_indices.strides()[0] : 0, /* const int64_t batch_stride_b = */ (batch_ndim > 0) ? rhs_indices.strides()[0] : 0, /* const int64_t batch_stride_d = */ M * N, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(lhs_indices, 2); compute_encoder.set_input_array(rhs_indices, 3); compute_encoder.set_output_array(out, 4); compute_encoder.set_bytes(params, 5); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7); compute_encoder.set_vector_bytes(rhs_indices.strides(), 8); compute_encoder.set_bytes(batch_ndim_a, 9); compute_encoder.set_vector_bytes(a.shape(), 10); compute_encoder.set_vector_bytes(a.strides(), 11); compute_encoder.set_bytes(batch_ndim_b, 12); compute_encoder.set_vector_bytes(b.shape(), 13); compute_encoder.set_vector_bytes(b.strides(), 14); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void GatherMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); auto& a = inputs[0]; auto& b = inputs[1]; auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; // Return 0s if either input is empty if (a.size() == 0 || b.size() == 0) { array zero = array(0, a.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); // Extract shapes from inputs. int M = a.shape(-2); int N = b.shape(-1); int K = a.shape(-1); // We are walking a in order and b is also in order so we can batch up the // matmuls and reuse reading a and b. if (M == 1 && right_sorted_ == true) { gather_mm_rhs(a, b, rhs_indices, out, d, s); return; } // Route to gather gemv if any of a or b are vectors if (M == 1) { gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s); return; } if (N == 1) { gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s); return; } // Route to non specialized gather mm gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } } // namespace mlx::core