diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 7d2ccd87f..48014a640 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -15,6 +15,8 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
+#include "mlx/internal/tuner/primitives.h"
+
 namespace mlx::core {
 
 namespace {
@@ -1848,4 +1850,219 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   d.add_temporaries(std::move(copies), s.index);
 }
 
+namespace internal {
+
+// Evaluate the tunable matmul on the GPU (Metal backend).
+//
+// Encodes a "steel_gemm_fused" kernel that reads inputs[0] (A) and inputs[1]
+// (B) and writes outputs[0]. The tile sizes bm/bn/bk and the threadgroup
+// factors wm/wn are read from tparams_ rather than chosen here.
+//
+// inputs:  exactly two arrays (asserted).
+// outputs: outputs[0] is bound as the kernel output; if either input is
+//          empty it is zero-filled and no GEMM is dispatched.
+void TunableMatmul::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 2);
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+  auto& out = outputs[0];
+  // Return 0s if either input is empty
+  if (a_pre.size() == 0 || b_pre.size() == 0) {
+    array zero = array(0, a_pre.dtype());
+    fill_gpu(zero, out, s);
+    d.add_temporary(std::move(zero), s.index);
+    return;
+  }
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
+  // Keep a vector with copies to be cleared in the completed buffer to release
+  // the arrays
+  std::vector<array> copies;
+  // Inspect the last two strides of `arr` and return
+  // (transposed, leading dimension, array to bind). Unit stride in the last
+  // dim means not transposed; unit stride in the second-to-last dim means
+  // transposed. Any other layout is copied with copy_gpu and the copy is
+  // kept in `copies`. `is_vector` adds a check on the other stride.
+  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      size_t copy_ld = arr.shape(-1);
+      return std::make_tuple(false, copy_ld, arr_copy);
+    }
+  };
+
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Check and collapse batch dimensions
+
+  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
+
+  auto batch_size_out = out.size() / (size_t(M) * size_t(N));
+
+  // Collapse batches into M if needed
+  // (A is not transposed, its rows and its single batch dim are densely
+  // packed, and B has batch stride 0, so the batch can be folded into M.)
+  if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+      B_batch_stride.back() == 0) {
+    M *= batch_shape.back();
+    batch_size_out = 1;
+
+    A_batch_stride = {0};
+    B_batch_stride = {0};
+    batch_shape = {1};
+  }
+
+  // Batch strides of A followed by those of B.
+  auto batch_strides = A_batch_stride;
+  batch_strides.insert(
+      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+
+  using namespace mlx::steel;
+
+  // Determine dispatch kernel
+  int bm = tparams_["bm"];
+  int bn = tparams_["bn"];
+  int bk = tparams_["bk"];
+  int wm = tparams_["wm"];
+  int wn = tparams_["wn"];
+  // Every parameter is used below as a divisor or a launch dimension.
+  assert(bm > 0 && bn > 0 && bk > 0 && wm > 0 && wn > 0);
+
+  // Prepare kernel name
+  std::ostringstream kname;
+  kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn;
+
+  std::string base_name = kname.str();
+
+  const bool has_batch = (batch_shape.size() > 1);
+  const bool use_out_source = false;
+  const bool do_axpby = false;
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+  const bool align_K = (K % bk) == 0;
+  const bool do_gather = false;
+
+  // Entries are (value, type, index); presumably the kernel's Metal function
+  // constants -- confirm against the steel GEMM kernel source.
+  metal::MTLFCList func_consts = {
+      {&has_batch, MTL::DataType::DataTypeBool, 10},
+      {&use_out_source, MTL::DataType::DataTypeBool, 100},
+      {&do_axpby, MTL::DataType::DataTypeBool, 110},
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+      {&align_K, MTL::DataType::DataTypeBool, 202},
+      {&do_gather, MTL::DataType::DataTypeBool, 300},
+  };
+
+  // clang-format off
+  kname << "_has_batch_" << (has_batch ? 't' : 'n')
+        << "_use_out_source_" << (use_out_source ? 't' : 'n')
+        << "_do_axpby_" << (do_axpby ? 't' : 'n')
+        << "_align_M_" << (align_M ? 't' : 'n')
+        << "_align_N_" << (align_N ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n')
+        << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
+
+  std::string hash_name = kname.str();
+
+  // Encode and dispatch kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_steel_gemm_fused_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
+
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Use problem size to determine threadblock swizzle
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  size_t matrix_stride_out = size_t(M) * N;
+
+  // Prepare steel matmul params
+  GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ int(a_cols),
+      /* const int ldb = */ int(b_cols),
+      /* const int ldd = */ N,
+      /* const int tiles_n = */ tn,
+      /* const int tiles_m = */ tm,
+      /* const size_t batch_stride_a = */ A_batch_stride.back(),
+      /* const size_t batch_stride_b = */ B_batch_stride.back(),
+      /* const size_t batch_stride_d = */ matrix_stride_out,
+      /* const int swizzle_log = */ swizzle_log,
+      /* const int gemm_k_iterations_aligned = */ (K / bk),
+      /* const int batch_ndim = */ int(batch_shape.size())};
+
+  // Prepare launch grid params
+  // swizzle_log is 0 above, so tile == 1 and tm / tn are unchanged here.
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  // One threadgroup of 32 x wn x wm threads per (tn, tm, batch) grid cell.
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
+
+  // Launch kernel
+  // Slots: 0 = A, 1 = B, 3 = out, 4 = params, 6 = batch shape, 7 = strides.
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_output_array(out, 3);
+
+  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+
+  set_vector_bytes(compute_encoder, batch_shape, 6);
+  set_vector_bytes(compute_encoder, batch_strides, 7);
+
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+
+  // Record copies
+  d.add_temporaries(std::move(copies), s.index);
+}
+
+} // namespace internal
+
 } // namespace mlx::core