MLX
Loading...
Searching...
No Matches
fast_primitives.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include "mlx/primitives.h"
4
5namespace mlx::core::fast {
6
7// Custom primitive accepts a fallback function which it uses for
8// transformations. Transformations are virtual so that derived classes may
9// override the default behavior.
10class Custom : public Primitive {
11 public:
12 explicit Custom(
14 std::function<std::vector<array>(std::vector<array>)> fallback)
15 : Primitive(stream), fallback_(fallback) {}
16
17 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
18 const std::vector<array>& inputs,
19 const std::vector<int>& axes) override;
20
21 virtual std::vector<array> jvp(
22 const std::vector<array>& primals,
23 const std::vector<array>& tangents,
24 const std::vector<int>& argnums) override;
25
26 virtual std::vector<array> vjp(
27 const std::vector<array>& primals,
28 const std::vector<array>& cotangents,
29 const std::vector<int>& argnums,
30 const std::vector<array>& outputs) override;
31
32 private:
33 std::function<std::vector<array>(std::vector<array>)> fallback_;
34};
35
36class RMSNorm : public Custom {
37 public:
40 std::function<std::vector<array>(std::vector<array>)> fallback,
41 float eps)
42 : Custom(stream, fallback), eps_(eps) {}
43
44 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
45 override {
46 throw std::runtime_error("NYI");
47 }
48 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
49 override;
50
51 std::vector<array> vjp(
52 const std::vector<array>& primals,
53 const std::vector<array>& cotangents,
54 const std::vector<int>& argnums,
55 const std::vector<array>& outputs) override;
56
58 bool is_equivalent(const Primitive& other) const override;
59
60 private:
61 std::function<std::vector<array>(std::vector<array>)> fallback_;
62 float eps_;
63};
64
65class RMSNormVJP : public Custom {
66 public:
69 std::function<std::vector<array>(std::vector<array>)> fallback,
70 float eps)
71 : Custom(stream, fallback), eps_(eps) {}
72
73 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
74 override {
75 throw std::runtime_error("NYI");
76 }
77 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
78 override;
79
81 bool is_equivalent(const Primitive& other) const override;
82
83 private:
84 std::function<std::vector<array>(std::vector<array>)> fallback_;
85 float eps_;
86};
87
88class LayerNorm : public Custom {
89 public:
92 std::function<std::vector<array>(std::vector<array>)> fallback,
93 float eps)
94 : Custom(stream, fallback), eps_(eps) {}
95
96 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
97 override {
98 throw std::runtime_error("NYI");
99 }
100 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
101 override;
102
103 std::vector<array> vjp(
104 const std::vector<array>& primals,
105 const std::vector<array>& cotangents,
106 const std::vector<int>& argnums,
107 const std::vector<array>& outputs) override;
108
110 bool is_equivalent(const Primitive& other) const override;
111
112 private:
113 std::function<std::vector<array>(std::vector<array>)> fallback_;
114 float eps_;
115};
116
117class LayerNormVJP : public Custom {
118 public:
121 std::function<std::vector<array>(std::vector<array>)> fallback,
122 float eps)
123 : Custom(stream, fallback), eps_(eps) {}
124
125 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
126 override {
127 throw std::runtime_error("NYI");
128 }
129 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
130 override;
131
133 bool is_equivalent(const Primitive& other) const override;
134
135 private:
136 std::function<std::vector<array>(std::vector<array>)> fallback_;
137 float eps_;
138};
139
140class RoPE : public Custom {
141 public:
144 std::function<std::vector<array>(std::vector<array>)> fallback,
145 int dims,
146 bool traditional,
147 float base,
148 float scale,
149 int offset,
150 bool forward)
151 : Custom(stream, fallback),
152 dims_(dims),
153 traditional_(traditional),
154 base_(base),
155 scale_(scale),
156 offset_(offset),
157 forward_(forward) {}
158
159 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
160 override {
161 throw std::runtime_error("NYI");
162 }
163 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
164 override;
165
166 std::vector<array> vjp(
167 const std::vector<array>& primals,
168 const std::vector<array>& cotangents,
169 const std::vector<int>& argnums,
170 const std::vector<array>& outputs) override;
171
173 bool is_equivalent(const Primitive& other) const override;
174
175 private:
176 std::function<std::vector<array>(std::vector<array>)> fallback_;
177 int dims_;
178 bool traditional_;
179 float base_;
180 float scale_;
181 int offset_;
182 bool forward_;
183};
184
186 public:
189 std::function<std::vector<array>(std::vector<array>)> fallback,
190 const float scale,
191 const bool needs_mask)
192 : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
193
194 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
195 override {
196 throw std::runtime_error("NYI");
197 }
198
199 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
200 override {
201 eval_gpu(inputs, outputs[0]);
202 }
203
204 void eval_gpu(const std::vector<array>& inputs, array& out);
205 bool is_equivalent(const Primitive& other) const override;
206
208
209 private:
210 std::function<std::vector<array>(std::vector<array>)> fallback_;
211 float scale_;
212 bool needs_mask_;
213};
214
215} // namespace mlx::core::fast
Definition primitives.h:48
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Definition array.h:20
Definition fast_primitives.h:10
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:12
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition fast_primitives.h:88
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:90
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:96
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:117
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:125
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:119
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:36
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:38
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:44
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:65
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:67
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:73
Definition fast_primitives.h:140
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
Definition fast_primitives.h:142
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:159
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:185
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:199
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
Definition fast_primitives.h:187
DEFINE_PRINT(ScaledDotProductAttention)
void eval_gpu(const std::vector< array > &inputs, array &out)
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition fast_primitives.h:194
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition fast.h:9
Definition stream.h:9