MLX
Loading...
Searching...
No Matches
fast_primitives.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#pragma once

#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "mlx/primitives.h"

7namespace mlx::core::fast {
8
9// Custom primitive accepts a fallback function which it uses for
10// transformations. Transformations are virtual so that derived classes may
11// override the default behavior.
12class Custom : public Primitive {
13 public:
14 explicit Custom(
16 std::function<std::vector<array>(std::vector<array>)> fallback)
17 : Primitive(stream), fallback_(fallback) {}
18
19 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
20 const std::vector<array>& inputs,
21 const std::vector<int>& axes) override;
22
23 virtual std::vector<array> jvp(
24 const std::vector<array>& primals,
25 const std::vector<array>& tangents,
26 const std::vector<int>& argnums) override;
27
28 virtual std::vector<array> vjp(
29 const std::vector<array>& primals,
30 const std::vector<array>& cotangents,
31 const std::vector<int>& argnums,
32 const std::vector<array>& outputs) override;
33
34 private:
35 std::function<std::vector<array>(std::vector<array>)> fallback_;
36};
37
38class RMSNorm : public Custom {
39 public:
42 std::function<std::vector<array>(std::vector<array>)> fallback,
43 float eps)
44 : Custom(stream, fallback), eps_(eps) {}
45
46 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
47 override {
48 throw std::runtime_error("NYI");
49 }
50 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
51 override;
52
53 std::vector<array> vjp(
54 const std::vector<array>& primals,
55 const std::vector<array>& cotangents,
56 const std::vector<int>& argnums,
57 const std::vector<array>& outputs) override;
58
60 bool is_equivalent(const Primitive& other) const override;
62
63 private:
64 std::function<std::vector<array>(std::vector<array>)> fallback_;
65 float eps_;
66};
67
68class RMSNormVJP : public Custom {
69 public:
72 std::function<std::vector<array>(std::vector<array>)> fallback,
73 float eps)
74 : Custom(stream, fallback), eps_(eps) {}
75
76 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
77 override {
78 throw std::runtime_error("NYI");
79 }
80 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
81 override;
82
84 bool is_equivalent(const Primitive& other) const override;
85
86 private:
87 std::function<std::vector<array>(std::vector<array>)> fallback_;
88 float eps_;
89};
90
91class LayerNorm : public Custom {
92 public:
94 Stream stream,
95 std::function<std::vector<array>(std::vector<array>)> fallback,
96 float eps)
97 : Custom(stream, fallback), eps_(eps) {}
98
99 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
100 override {
101 throw std::runtime_error("NYI");
102 }
103 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
104 override;
105
106 std::vector<array> vjp(
107 const std::vector<array>& primals,
108 const std::vector<array>& cotangents,
109 const std::vector<int>& argnums,
110 const std::vector<array>& outputs) override;
111
113 bool is_equivalent(const Primitive& other) const override;
115
116 private:
117 std::function<std::vector<array>(std::vector<array>)> fallback_;
118 float eps_;
119};
120
121class LayerNormVJP : public Custom {
122 public:
124 Stream stream,
125 std::function<std::vector<array>(std::vector<array>)> fallback,
126 float eps)
127 : Custom(stream, fallback), eps_(eps) {}
128
129 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
130 override {
131 throw std::runtime_error("NYI");
132 }
133 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
134 override;
135
137 bool is_equivalent(const Primitive& other) const override;
138
139 private:
140 std::function<std::vector<array>(std::vector<array>)> fallback_;
141 float eps_;
142};
143
144class RoPE : public Custom {
145 public:
147 Stream stream,
148 std::function<std::vector<array>(std::vector<array>)> fallback,
149 int dims,
150 bool traditional,
151 float base,
152 float scale,
153 int offset,
154 bool forward)
155 : Custom(stream, fallback),
156 dims_(dims),
157 traditional_(traditional),
158 base_(base),
159 scale_(scale),
160 offset_(offset),
161 forward_(forward) {}
162
163 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
164 override {
165 throw std::runtime_error("NYI");
166 }
167 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
168 override;
169
170 std::vector<array> vjp(
171 const std::vector<array>& primals,
172 const std::vector<array>& cotangents,
173 const std::vector<int>& argnums,
174 const std::vector<array>& outputs) override;
175
177 bool is_equivalent(const Primitive& other) const override;
179
180 private:
181 std::function<std::vector<array>(std::vector<array>)> fallback_;
182 int dims_;
183 bool traditional_;
184 float base_;
185 float scale_;
186 int offset_;
187 bool forward_;
188};
189
191 public:
193 Stream stream,
194 std::function<std::vector<array>(std::vector<array>)> fallback,
195 const float scale,
196 const bool needs_mask)
197 : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
198
199 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
200 override {
201 throw std::runtime_error("NYI");
202 }
203
204 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
205 override {
206 eval_gpu(inputs, outputs[0]);
207 }
208
209 void eval_gpu(const std::vector<array>& inputs, array& out);
210 bool is_equivalent(const Primitive& other) const override;
211
214
215 private:
216 std::function<std::vector<array>(std::vector<array>)> fallback_;
217 float scale_;
218 bool needs_mask_;
219};
220
221class AffineQuantize : public Custom {
222 public:
224 Stream stream,
225 std::function<std::vector<array>(std::vector<array>)> fallback,
226 int group_size,
227 int bits,
228 bool dequantize)
229 : Custom(stream, fallback),
230 group_size_(group_size),
231 bits_(bits),
232 dequantize_(dequantize) {}
233
234 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
235 override;
236
237 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
238 override;
239
241
242 bool is_equivalent(const Primitive& other) const override;
243 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
244
245 private:
246 std::function<std::vector<array>(std::vector<array>)> fallback_;
247 int group_size_;
248 int bits_;
249 bool dequantize_;
250};
251
253 bool shape = false;
254 bool strides = false;
255 bool ndim = false;
256};
257
258class CustomKernel : public Primitive {
259 public:
261 Stream stream,
262 std::string name,
263 std::string source,
264 std::tuple<int, int, int> grid,
265 std::tuple<int, int, int> threadgroup,
266 std::vector<CustomKernelShapeInfo> shape_infos,
267 bool ensure_row_contiguous,
268 std::optional<float> init_value)
269 : Primitive(stream),
270 source_(std::move(source)),
271 name_(std::move(name)),
272 grid_(grid),
273 threadgroup_(threadgroup),
274 shape_infos_(std::move(shape_infos)),
275 ensure_row_contiguous_(ensure_row_contiguous),
276 init_value_(init_value) {}
277
278 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
279 override {
280 throw std::runtime_error("Custom Metal kernels only run on GPU.");
281 }
282
283 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
284 override;
285
287
288 private:
289 std::string source_;
290 std::string name_;
291 std::tuple<int, int, int> grid_;
292 std::tuple<int, int, int> threadgroup_;
293 std::vector<CustomKernelShapeInfo> shape_infos_;
294 bool ensure_row_contiguous_;
295 std::optional<float> init_value_;
296};
297
298} // namespace mlx::core::fast
Definition primitives.h:48
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Definition array.h:23
Definition fast_primitives.h:221
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
AffineQuantize(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int group_size, int bits, bool dequantize)
Definition fast_primitives.h:223
Definition fast_primitives.h:12
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:14
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition fast_primitives.h:258
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:278
CustomKernel(Stream stream, std::string name, std::string source, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::vector< CustomKernelShapeInfo > shape_infos, bool ensure_row_contiguous, std::optional< float > init_value)
Definition fast_primitives.h:260
Definition fast_primitives.h:91
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:93
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:99
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:121
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:129
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:123
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:38
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:40
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:46
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_INPUT_OUTPUT_SHAPE() private float eps_
Definition fast_primitives.h:61
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:68
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:70
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:76
Definition fast_primitives.h:144
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
Definition fast_primitives.h:146
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:163
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:190
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:204
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
Definition fast_primitives.h:192
DEFINE_PRINT(ScaledDotProductAttention)
void eval_gpu(const std::vector< array > &inputs, array &out)
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:199
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Definition fast.h:9
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
void dequantize(const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
Definition quantized.h:372
Definition stream.h:9
Definition fast_primitives.h:252