diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 1c9ee5a6c..03d5a89cb 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1000,124 +1000,37 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - Strides batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); batch_strides.insert( batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride.back(), + /* int64_t B_batch_stride = */ B_batch_stride.back(), + /* int64_t matrix_stride_out = */ int64_t(M) * ldd, + /* int64_t C_batch_stride = */ C_batch_stride.back(), + /* float alpha = */ alpha_, + /* float beta = */ beta_); } void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) {