MLX
Loading...
Searching...
No Matches
Namespaces | Typedefs | Functions
fast.h File Reference
#include <optional>
#include "mlx/utils.h"

Go to the source code of this file.

Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::fast
 

Typedefs

typedef std::variant< int, bool, Dtype > mlx::core::fast::TemplateArg
 
typedef std::function< std::vector< array >(const std::vector< array > &, const std::vector< std::vector< int > > &, const std::vector< Dtype > &, std::tuple< int, int, int >, std::tuple< int, int, int >, std::vector< std::pair< std::string, TemplateArg > >, std::optional< float >, bool, StreamOrDevice)> mlx::core::fast::MetalKernelFunction
 

Functions

array mlx::core::fast::rms_norm (const array &x, const array &weight, float eps, StreamOrDevice s={})
 
array mlx::core::fast::layer_norm (const array &x, const std::optional< array > &weight, const std::optional< array > &bias, float eps, StreamOrDevice s={})
 
array mlx::core::fast::rope (const array &x, int dims, bool traditional, std::optional< float > base, float scale, int offset, const std::optional< array > &freqs=std::nullopt, StreamOrDevice s={})
 
array mlx::core::fast::scaled_dot_product_attention (const array &queries, const array &keys, const array &values, const float scale, const std::optional< array > &mask=std::nullopt, const std::optional< int > memory_efficient_threshold=std::nullopt, StreamOrDevice s={})
 Computes: O = softmax(Q @ K.T) @ V.
 
std::tuple< array, array, array > mlx::core::fast::affine_quantize (const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
 
array mlx::core::fast::affine_quantize (const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 
array mlx::core::fast::affine_dequantize (const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 
MetalKernelFunction mlx::core::fast::metal_kernel (const std::string &name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names, const std::string &source, const std::string &header="", bool ensure_row_contiguous=true, bool atomic_outputs=false)