12    const std::string& kernel_name,
 
   17    const std::string& kernel_name,
 
   19    const std::string op);
 
   23    const std::string& kernel_name,
 
   26    const std::string op);
 
   30    const std::string& kernel_name,
 
   33    const std::string op);
 
   37    const std::string& kernel_name,
 
   39    const std::string op);
 
   43    const std::string& kernel_name,
 
   49    const std::string& kernel_name,
 
   55    const std::string& kernel_name,
 
   58    const std::string& reduce_type,
 
   64    const std::string& kernel_name,
 
   72    const std::string& kernel_name,
 
   80    const std::string& kernel_name,
 
   85    const std::string& kernel_name,
 
   86    const std::string& op_name,
 
   92    const std::string& kernel_name,
 
   93    const std::string& hash_name,
 
  106    const std::string& kernel_name,
 
  121    const std::string& kernel_name,
 
  128    const std::string& kernel_name,
 
  130    const std::optional<array>& mask_out,
 
  131    const std::optional<array>& mask_op,
 
  144    const std::string& kernel_name,
 
  151    int n_channel_specialization,
 
  156    const std::string& kernel_name,
 
  158    const std::optional<array>& mask_out,
 
  159    const std::optional<array>& mask_op,
 
  171    const std::string& kernel_name,
 
  181    const std::string& kernel_name,
 
  182    const std::string& hash_name,
 
  184    const std::string& template_def);
 
  188    const std::string& kernel_name,
 
  189    const std::string& template_def);
 
  192template <typename... Args>
 
  195  std::ostringstream s;
 
  198  auto add_arg = [&s, &first](const auto& arg) {
 
  205  (add_arg(args), ...);
 
  207  std::string base_string = R"(
  208template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
  210  return fmt::format(base_string, name, s.str());
 
 
Op op
Definition binary.h:141
 
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
 
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
 
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
 
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
 
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &op_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
 
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
 
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
 
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:194
 
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
 
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
 
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)