MLX
Loading...
Searching...
No Matches
fast_primitives.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.
2
#include <optional>
#include <utility>

#include "mlx/primitives.h"
6
7namespace mlx::core::fast {
8
9// Custom primitive accepts a fallback function which it uses for
10// transformations. Transformations are virtual so that derived classes may
11// override the default behavior.
12class Custom : public Primitive {
13 public:
14 explicit Custom(
16 std::function<std::vector<array>(std::vector<array>)> fallback)
17 : Primitive(stream), fallback_(fallback) {}
18
19 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
20 const std::vector<array>& inputs,
21 const std::vector<int>& axes) override;
22
23 virtual std::vector<array> jvp(
24 const std::vector<array>& primals,
25 const std::vector<array>& tangents,
26 const std::vector<int>& argnums) override;
27
28 virtual std::vector<array> vjp(
29 const std::vector<array>& primals,
30 const std::vector<array>& cotangents,
31 const std::vector<int>& argnums,
32 const std::vector<array>& outputs) override;
33
34 private:
35 std::function<std::vector<array>(std::vector<array>)> fallback_;
36};
37
38class RMSNorm : public Custom {
39 public:
42 std::function<std::vector<array>(std::vector<array>)> fallback,
43 float eps)
44 : Custom(stream, fallback), eps_(eps) {}
45
46 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
47 override {
48 throw std::runtime_error("NYI");
49 }
50 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
51 override;
52
53 std::vector<array> vjp(
54 const std::vector<array>& primals,
55 const std::vector<array>& cotangents,
56 const std::vector<int>& argnums,
57 const std::vector<array>& outputs) override;
58
60 bool is_equivalent(const Primitive& other) const override;
61
62 private:
63 std::function<std::vector<array>(std::vector<array>)> fallback_;
64 float eps_;
65};
66
67class RMSNormVJP : public Custom {
68 public:
71 std::function<std::vector<array>(std::vector<array>)> fallback,
72 float eps)
73 : Custom(stream, fallback), eps_(eps) {}
74
75 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
76 override {
77 throw std::runtime_error("NYI");
78 }
79 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
80 override;
81
83 bool is_equivalent(const Primitive& other) const override;
84
85 private:
86 std::function<std::vector<array>(std::vector<array>)> fallback_;
87 float eps_;
88};
89
90class LayerNorm : public Custom {
91 public:
94 std::function<std::vector<array>(std::vector<array>)> fallback,
95 float eps)
96 : Custom(stream, fallback), eps_(eps) {}
97
98 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
99 override {
100 throw std::runtime_error("NYI");
101 }
102 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
103 override;
104
105 std::vector<array> vjp(
106 const std::vector<array>& primals,
107 const std::vector<array>& cotangents,
108 const std::vector<int>& argnums,
109 const std::vector<array>& outputs) override;
110
112 bool is_equivalent(const Primitive& other) const override;
113
114 private:
115 std::function<std::vector<array>(std::vector<array>)> fallback_;
116 float eps_;
117};
118
119class LayerNormVJP : public Custom {
120 public:
123 std::function<std::vector<array>(std::vector<array>)> fallback,
124 float eps)
125 : Custom(stream, fallback), eps_(eps) {}
126
127 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
128 override {
129 throw std::runtime_error("NYI");
130 }
131 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
132 override;
133
135 bool is_equivalent(const Primitive& other) const override;
136
137 private:
138 std::function<std::vector<array>(std::vector<array>)> fallback_;
139 float eps_;
140};
141
142class RoPE : public Custom {
143 public:
146 std::function<std::vector<array>(std::vector<array>)> fallback,
147 int dims,
148 bool traditional,
149 float base,
150 float scale,
151 int offset,
152 bool forward)
153 : Custom(stream, fallback),
154 dims_(dims),
155 traditional_(traditional),
156 base_(base),
157 scale_(scale),
158 offset_(offset),
159 forward_(forward) {}
160
161 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
162 override {
163 throw std::runtime_error("NYI");
164 }
165 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
166 override;
167
168 std::vector<array> vjp(
169 const std::vector<array>& primals,
170 const std::vector<array>& cotangents,
171 const std::vector<int>& argnums,
172 const std::vector<array>& outputs) override;
173
175 bool is_equivalent(const Primitive& other) const override;
176
177 private:
178 std::function<std::vector<array>(std::vector<array>)> fallback_;
179 int dims_;
180 bool traditional_;
181 float base_;
182 float scale_;
183 int offset_;
184 bool forward_;
185};
186
188 public:
191 std::function<std::vector<array>(std::vector<array>)> fallback,
192 const float scale,
193 const bool needs_mask)
194 : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
195
196 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
197 override {
198 throw std::runtime_error("NYI");
199 }
200
201 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
202 override {
203 eval_gpu(inputs, outputs[0]);
204 }
205
206 void eval_gpu(const std::vector<array>& inputs, array& out);
207 bool is_equivalent(const Primitive& other) const override;
208
210
211 private:
212 std::function<std::vector<array>(std::vector<array>)> fallback_;
213 float scale_;
214 bool needs_mask_;
215};
216
217class AffineQuantize : public Custom {
218 public:
221 std::function<std::vector<array>(std::vector<array>)> fallback,
222 int group_size,
223 int bits,
224 bool dequantize)
225 : Custom(stream, fallback),
226 group_size_(group_size),
227 bits_(bits),
228 dequantize_(dequantize) {}
229
230 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
231 override {
232 throw std::runtime_error("NYI");
233 }
234
235 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
236 override;
237
239
240 private:
241 std::function<std::vector<array>(std::vector<array>)> fallback_;
242 int group_size_;
243 int bits_;
244 bool dequantize_;
245};
246
// Three independent boolean flags, all defaulting to false. CustomKernel
// stores a vector of these (one entry per element of its `shape_infos`
// argument).
// NOTE(review): presumably each flag requests that the corresponding piece of
// array metadata be made available to the kernel -- confirm against the code
// that consumes shape_infos.
struct CustomKernelShapeInfo {
  bool shape = false;
  bool strides = false;
  bool ndim = false;
};
252
253class CustomKernel : public Primitive {
254 public:
257 std::string name,
258 std::string source,
259 std::tuple<int, int, int> grid,
260 std::tuple<int, int, int> threadgroup,
261 std::vector<CustomKernelShapeInfo> shape_infos,
262 bool ensure_row_contiguous,
263 std::optional<float> init_value)
264 : Primitive(stream),
265 source_(source),
266 name_(name),
267 grid_(grid),
268 threadgroup_(threadgroup),
269 shape_infos_(shape_infos),
270 ensure_row_contiguous_(ensure_row_contiguous),
271 init_value_(init_value) {}
272
273 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
274 override {
275 throw std::runtime_error("Custom Metal kernels only run on GPU.");
276 }
277
278 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
279 override;
280
282
283 private:
284 std::string source_;
285 std::string name_;
286 std::tuple<int, int, int> grid_;
287 std::tuple<int, int, int> threadgroup_;
288 std::vector<CustomKernelShapeInfo> shape_infos_;
289 bool ensure_row_contiguous_;
290 std::optional<float> init_value_;
291};
292
293} // namespace mlx::core::fast
Definition primitives.h:48
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Definition array.h:20
Definition fast_primitives.h:217
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:230
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
AffineQuantize(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int group_size, int bits, bool dequantize)
Definition fast_primitives.h:219
Definition fast_primitives.h:12
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:14
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition fast_primitives.h:253
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:273
CustomKernel(Stream stream, std::string name, std::string source, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::vector< CustomKernelShapeInfo > shape_infos, bool ensure_row_contiguous, std::optional< float > init_value)
Definition fast_primitives.h:255
Definition fast_primitives.h:90
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:92
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:98
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:119
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:127
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:121
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:38
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:40
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:46
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:67
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:69
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:75
Definition fast_primitives.h:142
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
Definition fast_primitives.h:144
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:161
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:187
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:201
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
Definition fast_primitives.h:189
DEFINE_PRINT(ScaledDotProductAttention)
void eval_gpu(const std::vector< array > &inputs, array &out)
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:196
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
Definition fast.h:9
Definition stream.h:9
Definition fast_primitives.h:247
bool strides
Definition fast_primitives.h:249
bool shape
Definition fast_primitives.h:248
bool ndim
Definition fast_primitives.h:250