MLX
Loading...
Searching...
No Matches
kernels.h File Reference
#include <fmt/format.h>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

Go to the source code of this file.

Namespaces

namespace mlx
 
namespace mlx::core
 

Functions

MTL::ComputePipelineState * mlx::core::get_arange_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * mlx::core::get_unary_kernel (metal::Device &d, const std::string &kernel_name, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * mlx::core::get_binary_kernel (metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * mlx::core::get_binary_two_kernel (metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
 
MTL::ComputePipelineState * mlx::core::get_ternary_kernel (metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
 
MTL::ComputePipelineState * mlx::core::get_copy_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
 
MTL::ComputePipelineState * mlx::core::get_softmax_kernel (metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
 
MTL::ComputePipelineState * mlx::core::get_scan_kernel (metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
 
MTL::ComputePipelineState * mlx::core::get_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
 
MTL::ComputePipelineState * mlx::core::get_mb_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
 
MTL::ComputePipelineState * mlx::core::get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * mlx::core::get_reduce_kernel (metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)
 
MTL::ComputePipelineState * mlx::core::get_steel_gemm_fused_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * mlx::core::get_steel_gemm_splitk_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * mlx::core::get_steel_gemm_splitk_accum_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
 
MTL::ComputePipelineState * mlx::core::get_steel_gemm_masked_kernel (metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * mlx::core::get_steel_conv_kernel (metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
 
MTL::ComputePipelineState * mlx::core::get_gemv_masked_kernel (metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
 
MTL::ComputePipelineState * mlx::core::get_steel_conv_general_kernel (metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * mlx::core::get_fft_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
 
MTL::ComputePipelineState * mlx::core::get_quantized_kernel (metal::Device &d, const std::string &kernel_name, const std::string &template_def)
 
template<typename... Args>
std::string mlx::core::get_template_definition (std::string name, std::string func, Args... args)