// Copyright © 2024 Apple Inc. #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" using namespace fmt::literals; namespace mlx::core { std::string op_name(const array& arr) { std::ostringstream op_t; arr.primitive().print(op_t); return op_t.str(); } MTL::ComputePipelineState* get_arange_kernel( metal::Device& d, const std::string& kernel_name, const array& out) { auto lib = d.get_library(kernel_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::arange(); kernel_source += get_template_definition( kernel_name, "arange", get_type_string(out.dtype())); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); if (get_work_per_thread(in_type) > 1) { kernel_source += get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } void append_binary_kernels( const std::string lib_name, Dtype in_type, Dtype out_type, const std::string op, std::string& kernel_source) { const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, {"g1large", "binary_g_nd1"}, {"g2large", "binary_g_nd2"}, {"g3large", "binary_g_nd3"}, }}; auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); for (auto& [name, func] : kernel_types) { kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } kernel_source += get_template_definition( "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); kernel_source += get_template_definition( "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); kernel_source += get_template_definition( "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); if (get_work_per_thread(in_type) > 1) { kernel_source += get_template_definition( "vsn_" + lib_name, "binary_vs", in_t, out_t, op); kernel_source += get_template_definition( "svn_" + lib_name, "binary_sv", in_t, out_t, op); kernel_source += get_template_definition( "vvn_" + lib_name, "binary_vv", in_t, out_t, op); } kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4); } MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; kernel_source = metal::utils(); concatenate(kernel_source, metal::binary_ops(), metal::binary()); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::binary_ops(), metal::binary_two()); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); const std::array, 4> kernel_types = {{ {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, {"g3large", "ternary_g_nd3"}, }}; for (auto& [name, func] : kernel_types) { kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } if (get_work_per_thread(type) > 1) { kernel_source += get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); } kernel_source += get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "ternary_g_nd2", t_str, op, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "ternary_g_nd3", t_str, op, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "ternary_g", t_str, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "ternary_g", t_str, op, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition( "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); kernel_source += get_template_definition( "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); if (get_work_per_thread(out.dtype()) > 1) { kernel_source += get_template_definition( "sn_" + lib_name, "copy_s", in_type, out_type); kernel_source += get_template_definition( "vn_" + lib_name, "copy_v", in_type, out_type); } kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "g1large_" + lib_name, "copy_g_nd1", in_type, out_type); kernel_source += get_template_definition( "g2large_" + lib_name, "copy_g_nd2", in_type, out_type); kernel_source += get_template_definition( "g3large_" + lib_name, "copy_g_nd3", in_type, out_type); kernel_source += get_template_definition( "gn4large_" + lib_name, "copy_g", in_type, out_type, 4); kernel_source += get_template_definition( "gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type); kernel_source += get_template_definition( "gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type); kernel_source += get_template_definition( "gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type); kernel_source += get_template_definition( "ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_dynamic_copy_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition( "gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type); kernel_source += get_template_definition( "gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type); kernel_source += get_template_definition( "gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type); kernel_source += get_template_definition( "ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_softmax_kernel( metal::Device& d, const std::string& kernel_name, bool precise, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { std::string kernel_source = metal::utils(); auto in_type = get_type_string(out.dtype()); auto acc_type = get_type_string(precise ? float32 : out.dtype()); kernel_source += metal::softmax(); kernel_source += get_template_definition( "block_" + lib_name, "softmax_single_row", in_type, acc_type); kernel_source += get_template_definition( "looped_" + lib_name, "softmax_looped", in_type, acc_type); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_logsumexp_kernel( metal::Device& d, const std::string& kernel_name, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { auto t_str = get_type_string(out.dtype()); std::string kernel_source; kernel_source = metal::utils(); kernel_source += metal::logsumexp(); kernel_source += get_template_definition("block_" + lib_name, "logsumexp", t_str); kernel_source += get_template_definition( "looped_" + lib_name, "logsumexp_looped", t_str); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, bool reverse, bool inclusive, const std::string& reduce_type, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto out_type = get_type_string(out.dtype()); std::string op = "Cum" + reduce_type + "<" + out_type + ">"; op[3] = toupper(op[3]); std::ostringstream kernel_source; kernel_source << metal::utils() << metal::scan(); const std::array, 2> scan_kernels = {{ {"contig_", "contiguous_scan"}, {"strided_", "strided_scan"}, }}; for (auto& [prefix, kernel] : scan_kernels) { kernel_source << get_template_definition( prefix + lib_name, kernel, get_type_string(in.dtype()), get_type_string(out.dtype()), op, in.itemsize() <= 4 ? 4 : 2, inclusive, reverse); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_sort_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, int bn, int tn) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source << metal::utils() << metal::sort(); for (bool is_argsort : {true, false}) { std::string bool_string = is_argsort ? "true" : "false"; std::string func_string = is_argsort ? "carg_" : "c_"; kernel_source << get_template_definition( func_string + lib_name, "block_sort", in_type, out_type, bool_string, bn, tn); kernel_source << get_template_definition( "n" + func_string + lib_name, "block_sort_nc", in_type, out_type, bool_string, bn, tn); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_mb_sort_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& idx, int bn, int tn) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::sort(); std::array, 3> kernel_types = { {{"sort_", "mb_block_sort"}, {"partition_", "mb_block_partition"}, {"merge_", "mb_block_merge"}}}; for (auto& [name, func] : kernel_types) { kernel_source << get_template_definition( name + lib_name, func, get_type_string(in.dtype()), get_type_string(idx.dtype()), "true", bn, tn); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_reduce_init_kernel( metal::Device& d, const std::string& kernel_name, const std::string& func_name, const std::string& op_name, const Dtype& out_type) { auto lib = d.get_library(kernel_name, [&]() { std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); auto out_t = get_type_string(out_type); std::string op = op_type + "<" + out_t + ">"; std::string kernel_source = metal::utils(); kernel_source += metal::reduce_utils(); kernel_source += metal::reduce(); kernel_source += get_template_definition(kernel_name, func_name, out_t, op); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, const std::string& func_name, const std::string& op_name, const Dtype& in_type, const Dtype& out_type, const std::string& idx_t, int ndim /* = -1 */, int bm /* = -1 */, int bn /* = -1 */) { auto lib = d.get_library(kernel_name, [&]() { std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); std::string op = op_type + "<" + out_t + ">"; std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::reduce_utils(), metal::reduce()); if (bm >= 0) { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn); } else if (ndim >= 0) { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t, ndim); } else { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t); } return kernel_source; }); auto st = d.get_kernel(kernel_name, lib); return st; } MTL::ComputePipelineState* get_steel_gemm_fused_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_fused() << get_template_definition( lib_name, "gemm", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << get_template_definition( lib_name, "gemm_splitk", get_type_string(in.dtype()), get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b, mn_aligned, k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, bool axbpy) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << get_template_definition( lib_name, axbpy ? "gemm_splitk_accum_axpby" : "gemm_splitk_accum", get_type_string(in.dtype()), get_type_string(out.dtype())); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_masked_kernel( metal::Device& d, const std::string& kernel_name, const array& out, const std::optional& mask_out, const std::optional& mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto out_mask_type = mask_out.has_value() ? get_type_string((*mask_out).dtype()) : "nomask_t"; auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_masked() << get_template_definition( lib_name, "block_masked_gemm", get_type_string(out.dtype()), out_mask_type, op_mask_type, bm, bn, bk, wm, wn, transpose_a, transpose_b, mn_aligned, k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_gather_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool rhs) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm(), metal::steel_gemm_gather(), get_template_definition( lib_name, rhs ? "gather_mm_rhs" : "gather_mm", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, const array& out, const std::optional& mask_out, const std::optional& mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto out_mask_type = mask_out.has_value() ? get_type_string((*mask_out).dtype()) : "nomask_t"; auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemv_masked() << get_template_definition( lib_name, (transpose_mat) ? "gemv_t_masked" : "gemv_masked", get_type_string(out.dtype()), out_mask_type, op_mask_type, bm, bn, sm, sn, tm, tn, contiguous ? 0 : 1); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, const array& out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv() << get_template_definition( lib_name, "implicit_gemm_conv_2d", get_type_string(out.dtype()), bm, bn, bk, wm, wn, n_channel_specialization, small_filter); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv_general() << get_template_definition( lib_name, "implicit_gemm_conv_2d_general", get_type_string(out.dtype()), bm, bn, bk, wm, wn); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const std::string& template_def) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; std::string kernel_string; kernel_source << metal::fft() << template_def; return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, const std::string& template_def) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::quantized() << template_def; return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_gather_qmm_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& x, int group_size, int bits, int bm, int bn, int bk, int wm, int wn, bool transpose) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm(), metal::quantized(), get_template_definition( lib_name, "gather_qmm_rhs", get_type_string(x.dtype()), group_size, bits, bm, bn, bk, wm, wn, transpose)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } } // namespace mlx::core