#include <optional>
#include "mlx/utils.h"
Go to the source code of this file.
|
array | mlx::core::fast::rms_norm (const array &x, const array &weight, float eps, StreamOrDevice s={}) |
|
array | mlx::core::fast::layer_norm (const array &x, const std::optional<array> &weight, const std::optional<array> &bias, float eps, StreamOrDevice s={}) |
|
array | mlx::core::fast::rope (const array &x, int dims, bool traditional, float base, float scale, int offset, StreamOrDevice s={}) |
|
array | mlx::core::fast::scaled_dot_product_attention (const array &queries, const array &keys, const array &values, const float scale, const std::optional<array> &mask=std::nullopt, StreamOrDevice s={}) |
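| Computes: O = softmax(scale * (Q @ K.T)) @ V, where Q, K and V are `queries`, `keys` and `values`, `scale` is the `scale` argument, and an optional `mask` may also be supplied.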
|
|