MLX
 
Loading...
Searching...
No Matches
fast.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "mlx/utils.h"
9
10namespace mlx::core::fast {
11
13 const array& x,
14 const std::optional<array>& weight,
15 float eps,
16 StreamOrDevice s = {});
17
19 const array& x,
20 const std::optional<array>& weight,
21 const std::optional<array>& bias,
22 float eps,
23 StreamOrDevice s = {});
24
26 const array& x,
27 int dims,
28 bool traditional,
29 std::optional<float> base,
30 float scale,
31 int offset,
32 const std::optional<array>& freqs = std::nullopt,
33 StreamOrDevice s = {});
34
36 const array& x,
37 int dims,
38 bool traditional,
39 std::optional<float> base,
40 float scale,
41 const array& offset,
42 const std::optional<array>& freqs = std::nullopt,
43 StreamOrDevice s = {});
44
47 const array& queries,
48 const array& keys,
49 const array& values,
50 const float scale,
51 const std::variant<std::monostate, std::string, array>& mask = {},
52 const std::optional<int> memory_efficient_threshold = std::nullopt,
53 StreamOrDevice s = {});
54
55std::tuple<array, array, array> affine_quantize(
56 const array& w,
57 int group_size = 64,
58 int bits = 4,
59 StreamOrDevice s = {});
60
62 const array& w,
63 const array& scales,
64 const array& biases,
65 int group_size = 64,
66 int bits = 4,
67 StreamOrDevice s = {});
68
69typedef std::variant<int, bool, Dtype> TemplateArg;
70
71typedef std::function<std::vector<array>(
72 const std::vector<array>&,
73 const std::vector<Shape>&,
74 const std::vector<Dtype>&,
75 std::tuple<int, int, int>,
76 std::tuple<int, int, int>,
77 std::vector<std::pair<std::string, TemplateArg>>,
78 std::optional<float>,
79 bool,
82
84 const std::string& name,
85 const std::vector<std::string>& input_names,
86 const std::vector<std::string>& output_names,
87 const std::string& source,
88 const std::string& header = "",
89 bool ensure_row_contiguous = true,
90 bool atomic_outputs = false);
91
92} // namespace mlx::core::fast
Definition array.h:24
Definition fast.h:10
array layer_norm(const array &x, const std::optional< array > &weight, const std::optional< array > &bias, float eps, StreamOrDevice s={})
array affine_dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
array scaled_dot_product_attention(const array &queries, const array &keys, const array &values, const float scale, const std::variant< std::monostate, std::string, array > &mask={}, const std::optional< int > memory_efficient_threshold=std::nullopt, StreamOrDevice s={})
Computes: O = softmax(Q @ K.T) @ V.
array rope(const array &x, int dims, bool traditional, std::optional< float > base, float scale, int offset, const std::optional< array > &freqs=std::nullopt, StreamOrDevice s={})
array rms_norm(const array &x, const std::optional< array > &weight, float eps, StreamOrDevice s={})
std::variant< int, bool, Dtype > TemplateArg
Definition fast.h:69
std::function< std::vector< array >(const std::vector< array > &, const std::vector< Shape > &, const std::vector< Dtype > &, std::tuple< int, int, int >, std::tuple< int, int, int >, std::vector< std::pair< std::string, TemplateArg > >, std::optional< float >, bool, StreamOrDevice)> MetalKernelFunction
Definition fast.h:81
std::tuple< array, array, array > affine_quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
MetalKernelFunction metal_kernel(const std::string &name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names, const std::string &source, const std::string &header="", bool ensure_row_contiguous=true, bool atomic_outputs=false)
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15