MLX
 
Loading...
Searching...
No Matches
fast_primitives.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
#pragma once

#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "mlx/primitives.h"
6
7namespace mlx::core::fast {
8
9// Custom primitive accepts a fallback function which it uses for
10// transformations. Transformations are virtual so that derived classes may
11// override the default behavior.
12class Custom : public Primitive {
13 public:
14 explicit Custom(
16 std::function<std::vector<array>(std::vector<array>)> fallback)
17 : Primitive(stream), fallback_(fallback) {}
18
19 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
20 const std::vector<array>& inputs,
21 const std::vector<int>& axes) override;
22
23 virtual std::vector<array> jvp(
24 const std::vector<array>& primals,
25 const std::vector<array>& tangents,
26 const std::vector<int>& argnums) override;
27
28 virtual std::vector<array> vjp(
29 const std::vector<array>& primals,
30 const std::vector<array>& cotangents,
31 const std::vector<int>& argnums,
32 const std::vector<array>& outputs) override;
33
34 private:
35 std::function<std::vector<array>(std::vector<array>)> fallback_;
36};
37
38class RMSNorm : public Custom {
39 public:
42 std::function<std::vector<array>(std::vector<array>)> fallback,
43 float eps)
44 : Custom(stream, fallback), eps_(eps) {}
45
46 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
47 override {
48 throw std::runtime_error("NYI");
49 }
50 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
51 override;
52
53 std::vector<array> vjp(
54 const std::vector<array>& primals,
55 const std::vector<array>& cotangents,
56 const std::vector<int>& argnums,
57 const std::vector<array>& outputs) override;
58
60 bool is_equivalent(const Primitive& other) const override;
62
63 auto state() const {
64 return std::make_pair(nullptr, eps_);
65 }
66
67 private:
68 std::function<std::vector<array>(std::vector<array>)> fallback_;
69 float eps_;
70};
71
72class RMSNormVJP : public Custom {
73 public:
76 std::function<std::vector<array>(std::vector<array>)> fallback,
77 float eps)
78 : Custom(stream, fallback), eps_(eps) {}
79
80 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
81 override {
82 throw std::runtime_error("NYI");
83 }
84 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
85 override;
86
88 bool is_equivalent(const Primitive& other) const override;
89 auto state() const {
90 return std::make_pair(nullptr, eps_);
91 }
92
93 private:
94 std::function<std::vector<array>(std::vector<array>)> fallback_;
95 float eps_;
96};
97
98class LayerNorm : public Custom {
99 public:
102 std::function<std::vector<array>(std::vector<array>)> fallback,
103 float eps)
104 : Custom(stream, fallback), eps_(eps) {}
105
106 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
107 override {
108 throw std::runtime_error("NYI");
109 }
110 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
111 override;
112
113 std::vector<array> vjp(
114 const std::vector<array>& primals,
115 const std::vector<array>& cotangents,
116 const std::vector<int>& argnums,
117 const std::vector<array>& outputs) override;
118
120 bool is_equivalent(const Primitive& other) const override;
122 auto state() const {
123 return std::make_pair(nullptr, eps_);
124 }
125
126 private:
127 std::function<std::vector<array>(std::vector<array>)> fallback_;
128 float eps_;
129};
130
131class LayerNormVJP : public Custom {
132 public:
135 std::function<std::vector<array>(std::vector<array>)> fallback,
136 float eps)
137 : Custom(stream, fallback), eps_(eps) {}
138
139 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
140 override {
141 throw std::runtime_error("NYI");
142 }
143 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
144 override;
145
147 bool is_equivalent(const Primitive& other) const override;
148 auto state() const {
149 return std::make_pair(nullptr, eps_);
150 }
151
152 private:
153 std::function<std::vector<array>(std::vector<array>)> fallback_;
154 float eps_;
155};
156
157class RoPE : public Custom {
158 public:
161 std::function<std::vector<array>(std::vector<array>)> fallback,
162 int dims,
163 bool traditional,
164 float base,
165 float scale,
166 bool forward)
167 : Custom(stream, fallback),
168 dims_(dims),
169 traditional_(traditional),
170 base_(base),
171 scale_(scale),
172 forward_(forward) {}
173
174 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
175 override {
176 throw std::runtime_error("NYI");
177 }
178 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
179 override;
180
181 std::vector<array> vjp(
182 const std::vector<array>& primals,
183 const std::vector<array>& cotangents,
184 const std::vector<int>& argnums,
185 const std::vector<array>& outputs) override;
186
188 bool is_equivalent(const Primitive& other) const override;
190 auto state() const {
191 return std::make_tuple(
192 nullptr, dims_, traditional_, base_, scale_, forward_);
193 }
194
195 private:
196 std::function<std::vector<array>(std::vector<array>)> fallback_;
197 int dims_;
198 bool traditional_;
199 float base_;
200 float scale_;
201 bool forward_;
202};
203
205 public:
208 std::function<std::vector<array>(std::vector<array>)> fallback,
209 const float scale)
210 : Custom(stream, fallback), scale_(scale) {}
211
212 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
213 override {
214 throw std::runtime_error("NYI");
215 }
216
217 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
218 override {
219 eval_gpu(inputs, outputs[0]);
220 }
221
222 void eval_gpu(const std::vector<array>& inputs, array& out);
223 bool is_equivalent(const Primitive& other) const override;
224
227 auto state() const {
228 return std::make_pair(nullptr, scale_);
229 }
230
231 private:
232 std::function<std::vector<array>(std::vector<array>)> fallback_;
233 float scale_;
234};
235
236class AffineQuantize : public Custom {
237 public:
240 std::function<std::vector<array>(std::vector<array>)> fallback,
241 int group_size,
242 int bits,
243 bool dequantize)
244 : Custom(stream, fallback),
245 group_size_(group_size),
246 bits_(bits),
247 dequantize_(dequantize) {}
248
249 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
250 override;
251
252 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
253 override;
254
256
257 bool is_equivalent(const Primitive& other) const override;
258 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
259 auto state() const {
260 return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
261 }
262
263 private:
264 std::function<std::vector<array>(std::vector<array>)> fallback_;
265 int group_size_;
266 int bits_;
267 bool dequantize_;
268};
269
/**
 * Per-input flags for a CustomKernel: which pieces of an input's metadata
 * (shape, strides, ndim) are requested. All flags default to off.
 * Presumably each enabled flag exposes that metadata to the kernel source —
 * TODO confirm against the GPU implementation.
 */
struct CustomKernelShapeInfo {
  bool shape{false};
  bool strides{false};
  bool ndim{false};
};

276class CustomKernel : public Primitive {
277 public:
280 std::string name,
281 std::string source,
282 std::tuple<int, int, int> grid,
283 std::tuple<int, int, int> threadgroup,
284 std::vector<CustomKernelShapeInfo> shape_infos,
285 bool ensure_row_contiguous,
286 std::optional<float> init_value)
287 : Primitive(stream),
288 source_(std::move(source)),
289 name_(std::move(name)),
290 grid_(grid),
291 threadgroup_(threadgroup),
292 shape_infos_(std::move(shape_infos)),
293 ensure_row_contiguous_(ensure_row_contiguous),
294 init_value_(init_value) {}
295
296 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
297 override {
298 throw std::runtime_error("Custom Metal kernels only run on GPU.");
299 }
300
301 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
302 override;
303
305
306 private:
307 std::string source_;
308 std::string name_;
309 std::tuple<int, int, int> grid_;
310 std::tuple<int, int, int> threadgroup_;
311 std::vector<CustomKernelShapeInfo> shape_infos_;
312 bool ensure_row_contiguous_;
313 std::optional<float> init_value_;
314};
315
316} // namespace mlx::core::fast
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive(Stream stream)
Definition primitives.h:50
Definition array.h:24
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
AffineQuantize(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int group_size, int bits, bool dequantize)
Definition fast_primitives.h:238
auto state() const
Definition fast_primitives.h:259
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:14
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:296
CustomKernel(Stream stream, std::string name, std::string source, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::vector< CustomKernelShapeInfo > shape_infos, bool ensure_row_contiguous, std::optional< float > init_value)
Definition fast_primitives.h:278
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:100
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:106
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_INPUT_OUTPUT_SHAPE() auto state() const
Definition fast_primitives.h:121
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:139
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:133
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
auto state() const
Definition fast_primitives.h:148
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:40
DEFINE_INPUT_OUTPUT_SHAPE() auto state() const
Definition fast_primitives.h:61
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:46
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
auto state() const
Definition fast_primitives.h:89
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:74
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:80
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:174
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
DEFINE_INPUT_OUTPUT_SHAPE() auto state() const
Definition fast_primitives.h:189
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, bool forward)
Definition fast_primitives.h:159
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:217
DEFINE_INPUT_OUTPUT_SHAPE() auto state() const
Definition fast_primitives.h:226
DEFINE_PRINT(ScaledDotProductAttention)
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale)
Definition fast_primitives.h:206
void eval_gpu(const std::vector< array > &inputs, array &out)
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:212
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
Definition fast.h:9
Definition stream.h:9
Definition fast_primitives.h:270
bool strides
Definition fast_primitives.h:272
bool shape
Definition fast_primitives.h:271
bool ndim
Definition fast_primitives.h:273