// MLX: fast.h (source listing recovered from the generated API documentation page)
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <optional>

#include "mlx/utils.h"

9namespace mlx::core::fast {
10
12 const array& x,
13 const array& weight,
14 float eps,
15 StreamOrDevice s = {});
16
18 const array& x,
19 const std::optional<array>& weight,
20 const std::optional<array>& bias,
21 float eps,
22 StreamOrDevice s = {});
23
25 const array& x,
26 int dims,
27 bool traditional,
28 float base,
29 float scale,
30 int offset,
31 StreamOrDevice s = {});
32
35 const array& queries,
36 const array& keys,
37 const array& values,
38 const float scale,
39 const std::optional<array>& mask = std::nullopt,
40 StreamOrDevice s = {});
41
42} // namespace mlx::core::fast
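// ---------------------------------------------------------------------------
// Cross-reference notes recovered from the rendered documentation page
// (symbol names on some entries were lost in extraction):
//
//   Definition array.h:20   (symbol name lost; presumably `array`)
//   Definition fast.h:9     (namespace mlx::core::fast)
//
//   array layer_norm(const array& x, const std::optional<array>& weight,
//                    const std::optional<array>& bias, float eps,
//                    StreamOrDevice s = {})
//   array rope(const array& x, int dims, bool traditional, float base,
//              float scale, int offset, StreamOrDevice s = {})
//   array scaled_dot_product_attention(const array& queries,
//                                      const array& keys,
//                                      const array& values,
//                                      const float scale,
//                                      const std::optional<array>& mask = std::nullopt,
//                                      StreamOrDevice s = {})
//       Computes: O = softmax(Q @ K.T) @ V.
//   array rms_norm(const array& x, const array& weight, float eps,
//                  StreamOrDevice s = {})
//
//   StreamOrDevice = std::variant<std::monostate, Stream, Device>
//       Definition utils.h:14
// ---------------------------------------------------------------------------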