// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include "mlx/utils.h" namespace mlx::core::fast { array rms_norm( const array& x, const std::optional& weight, float eps, StreamOrDevice s = {}); array layer_norm( const array& x, const std::optional& weight, const std::optional& bias, float eps, StreamOrDevice s = {}); array rope( const array& x, int dims, bool traditional, std::optional base, float scale, int offset, const std::optional& freqs = std::nullopt, StreamOrDevice s = {}); array rope( const array& x, int dims, bool traditional, std::optional base, float scale, const array& offset, const std::optional& freqs = std::nullopt, StreamOrDevice s = {}); /** Computes: O = softmax(Q @ K.T) @ V **/ array scaled_dot_product_attention( const array& queries, const array& keys, const array& values, const float scale, const std::variant& mask = {}, const std::optional memory_efficient_threshold = std::nullopt, StreamOrDevice s = {}); std::tuple affine_quantize( const array& w, int group_size = 64, int bits = 4, StreamOrDevice s = {}); array affine_dequantize( const array& w, const array& scales, const array& biases, int group_size = 64, int bits = 4, StreamOrDevice s = {}); typedef std::variant TemplateArg; typedef std::function( const std::vector&, const std::vector&, const std::vector&, std::tuple, std::tuple, std::vector>, std::optional, bool, StreamOrDevice)> MetalKernelFunction; MetalKernelFunction metal_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, const std::string& source, const std::string& header = "", bool ensure_row_contiguous = true, bool atomic_outputs = false); } // namespace mlx::core::fast