12 const std::string& kernel_name,
17 const std::string& kernel_name,
20 const std::string op);
24 const std::string& kernel_name,
27 const std::string op);
31 const std::string& kernel_name,
34 const std::string op);
38 const std::string& kernel_name,
40 const std::string op);
44 const std::string& kernel_name,
50 const std::string& kernel_name,
56 const std::string& kernel_name,
62 const std::string& kernel_name,
65 const std::string& reduce_type,
71 const std::string& kernel_name,
79 const std::string& kernel_name,
87 const std::string& kernel_name,
88 const std::string& func_name,
89 const std::string& op_name,
90 const Dtype& out_type);
94 const std::string& kernel_name,
95 const std::string& func_name,
96 const std::string& op_name,
98 const Dtype& out_type,
99 const std::string& idx_t,
106 const std::string& kernel_name,
107 const std::string& hash_name,
120 const std::string& kernel_name,
135 const std::string& kernel_name,
142 const std::string& kernel_name,
144 const std::optional<array>& mask_out,
145 const std::optional<array>& mask_op,
158 const std::string& kernel_name,
165 int n_channel_specialization,
170 const std::string& kernel_name,
172 const std::optional<array>& mask_out,
173 const std::optional<array>& mask_op,
185 const std::string& kernel_name,
195 const std::string& kernel_name,
196 const std::string& hash_name,
198 const std::string& template_def);
202 const std::string& kernel_name,
203 const std::string& template_def);
206template <
typename...
Args>
209 std::ostringstream s;
212 auto add_arg = [&s, &first](
const auto& arg) {
219 (add_arg(args), ...);
222 "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
array contiguous(const array &a, bool allow_col_major=false, StreamOrDevice s={})
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &in_type, const Dtype &out_type, const std::string &idx_t, int ndim=-1, int bm=-1, int bn=-1)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:208
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
std::vector< array > Args
Definition export.h:11
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_dynamic_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &out_type)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)