19    const std::optional<array>& weight,
 
   20    const std::optional<array>& bias,
 
   28    std::optional<float> base,
 
   31    const std::optional<array>& freqs = std::nullopt,
 
   38    std::optional<float> base,
 
   41    const std::optional<array>& freqs = std::nullopt,
 
   50    const std::optional<array>& mask = std::nullopt,
 
   51    const std::optional<int> memory_efficient_threshold = std::nullopt,
 
   70 typedef std::function<std::vector<array>(
 
   71    const std::vector<array>&,
 
   72    const std::vector<Shape>&,
 
   73    const std::vector<Dtype>&,
 
   74    std::tuple<int, int, int>,
 
   75    std::tuple<int, int, int>,
 
   76    std::vector<std::pair<std::string, TemplateArg>>,
 
   83    const std::string& name,
 
   84    const std::vector<std::string>& input_names,
 
   85    const std::vector<std::string>& output_names,
 
   86    const std::string& source,
 
   87    const std::string& header = "",
 
   88    bool ensure_row_contiguous = true,
 
   89    bool atomic_outputs = false);
 
 
array layer_norm(const array &x, const std::optional< array > &weight, const std::optional< array > &bias, float eps, StreamOrDevice s={})
 
array affine_dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 
array scaled_dot_product_attention(const array &queries, const array &keys, const array &values, const float scale, const std::optional< array > &mask=std::nullopt, const std::optional< int > memory_efficient_threshold=std::nullopt, StreamOrDevice s={})
Computes: O = softmax(Q @ K.T) @ V.
 
array rope(const array &x, int dims, bool traditional, std::optional< float > base, float scale, int offset, const std::optional< array > &freqs=std::nullopt, StreamOrDevice s={})
 
std::variant< int, bool, Dtype > TemplateArg
Definition fast.h:68
 
std::function< std::vector< array >(const std::vector< array > &, const std::vector< Shape > &, const std::vector< Dtype > &, std::tuple< int, int, int >, std::tuple< int, int, int >, std::vector< std::pair< std::string, TemplateArg > >, std::optional< float >, bool, StreamOrDevice)> MetalKernelFunction
Definition fast.h:80
 
std::tuple< array, array, array > affine_quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
 
MetalKernelFunction metal_kernel(const std::string &name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names, const std::string &source, const std::string &header="", bool ensure_row_contiguous=true, bool atomic_outputs=false)
 
array rms_norm(const array &x, const array &weight, float eps, StreamOrDevice s={})
 
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15