MLX
Loading...
Searching...
No Matches
fast.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "mlx/utils.h"
9
10namespace mlx::core::fast {
11
13 const array& x,
14 const array& weight,
15 float eps,
16 StreamOrDevice s = {});
17
19 const array& x,
20 const std::optional<array>& weight,
21 const std::optional<array>& bias,
22 float eps,
23 StreamOrDevice s = {});
24
26 const array& x,
27 int dims,
28 bool traditional,
29 std::optional<float> base,
30 float scale,
31 int offset,
32 const std::optional<array>& freqs = std::nullopt,
33 StreamOrDevice s = {});
34
37 const array& queries,
38 const array& keys,
39 const array& values,
40 const float scale,
41 const std::optional<array>& mask = std::nullopt,
42 const std::optional<int>& memory_efficient_threshold = std::nullopt,
43 StreamOrDevice s = {});
44
45std::tuple<array, array, array> affine_quantize(
46 const array& w,
47 int group_size = 64,
48 int bits = 4,
49 StreamOrDevice s = {});
50
52 const array& w,
53 const array& scales,
54 const array& biases,
55 int group_size = 64,
56 int bits = 4,
57 StreamOrDevice s = {});
58
60 const array& w,
61 const array& scales,
62 const array& biases,
63 int group_size = 64,
64 int bits = 4,
65 StreamOrDevice s = {});
66
67typedef std::variant<int, bool, Dtype> TemplateArg;
68
70 public:
72 const std::string& name,
73 const std::string& source,
74 bool ensure_row_contiguous)
75 : name_(name),
76 source_(source),
77 ensure_row_contiguous_(ensure_row_contiguous) {}
78
79 std::map<std::string, array> operator()(
80 std::map<std::string, array>& inputs,
81 std::map<std::string, std::vector<int>> output_shapes,
82 std::map<std::string, Dtype> output_dtypes,
83 std::tuple<int, int, int> grid,
84 std::tuple<int, int, int> threadgroup,
85 std::optional<std::map<std::string, TemplateArg>> template_args =
86 std::nullopt,
87 bool verbose = false,
88 StreamOrDevice s = {});
89
90 private:
91 std::string name_;
92 std::string source_;
93 bool ensure_row_contiguous_ = true;
94};
95} // namespace mlx::core::fast
Definition array.h:20
Definition fast.h:69
MetalKernel(const std::string &name, const std::string &source, bool ensure_row_contiguous)
Definition fast.h:71
std::map< std::string, array > operator()(std::map< std::string, array > &inputs, std::map< std::string, std::vector< int > > output_shapes, std::map< std::string, Dtype > output_dtypes, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::optional< std::map< std::string, TemplateArg > > template_args=std::nullopt, bool verbose=false, StreamOrDevice s={})
Definition fast.h:10
array layer_norm(const array &x, const std::optional< array > &weight, const std::optional< array > &bias, float eps, StreamOrDevice s={})
array affine_dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
array rope(const array &x, int dims, bool traditional, std::optional< float > base, float scale, int offset, const std::optional< array > &freqs=std::nullopt, StreamOrDevice s={})
array scaled_dot_product_attention(const array &queries, const array &keys, const array &values, const float scale, const std::optional< array > &mask=std::nullopt, const std::optional< int > &memory_efficient_threshold=std::nullopt, StreamOrDevice s={})
Computes: O = softmax(Q @ K.T) @ V.
std::variant< int, bool, Dtype > TemplateArg
Definition fast.h:67
std::tuple< array, array, array > affine_quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
array rms_norm(const array &x, const array &weight, float eps, StreamOrDevice s={})
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14