14 std::function<std::vector<array>(std::vector<array>)> fallback)
17 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
18 const std::vector<array>& inputs,
19 const std::vector<int>& axes)
override;
21 virtual std::vector<array>
jvp(
22 const std::vector<array>& primals,
23 const std::vector<array>& tangents,
24 const std::vector<int>& argnums)
override;
26 virtual std::vector<array>
vjp(
27 const std::vector<array>& primals,
28 const std::vector<array>& cotangents,
29 const std::vector<int>& argnums,
30 const std::vector<array>& outputs)
override;
33 std::function<std::vector<array>(std::vector<array>)> fallback_;
40 std::function<std::vector<array>(std::vector<array>)> fallback,
44 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
46 throw std::runtime_error(
"NYI");
48 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
51 std::vector<array>
vjp(
52 const std::vector<array>& primals,
53 const std::vector<array>& cotangents,
54 const std::vector<int>& argnums,
55 const std::vector<array>& outputs)
override;
61 std::function<std::vector<array>(std::vector<array>)> fallback_;
69 std::function<std::vector<array>(std::vector<array>)> fallback,
73 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
75 throw std::runtime_error(
"NYI");
77 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
84 std::function<std::vector<array>(std::vector<array>)> fallback_;
92 std::function<std::vector<array>(std::vector<array>)> fallback,
96 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
98 throw std::runtime_error(
"NYI");
100 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
104 const std::vector<array>& primals,
105 const std::vector<array>& cotangents,
106 const std::vector<int>& argnums,
107 const std::vector<array>& outputs)
override;
113 std::function<std::vector<array>(std::vector<array>)> fallback_;
121 std::function<std::vector<array>(std::vector<array>)> fallback,
125 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
127 throw std::runtime_error(
"NYI");
129 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
136 std::function<std::vector<array>(std::vector<array>)> fallback_;
144 std::function<std::vector<array>(std::vector<array>)> fallback,
153 traditional_(traditional),
159 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
161 throw std::runtime_error(
"NYI");
163 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
167 const std::vector<array>& primals,
168 const std::vector<array>& cotangents,
169 const std::vector<int>& argnums,
170 const std::vector<array>& outputs)
override;
176 std::function<std::vector<array>(std::vector<array>)> fallback_;
189 std::function<std::vector<array>(std::vector<array>)> fallback,
191 const bool needs_mask)
192 :
Custom(
stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
194 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
196 throw std::runtime_error(
"NYI");
199 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
210 std::function<std::vector<array>(std::vector<array>)> fallback_;
219 std::function<std::vector<array>(std::vector<array>)> fallback,
224 group_size_(group_size),
228 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
230 throw std::runtime_error(
"NYI");
233 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
239 std::function<std::vector<array>(std::vector<array>)> fallback_;
257 std::tuple<int, int, int> grid,
258 std::tuple<int, int, int> threadgroup,
259 std::vector<CustomKernelShapeInfo> shape_infos,
260 bool ensure_row_contiguous)
265 threadgroup_(threadgroup),
266 shape_infos_(shape_infos),
267 ensure_row_contiguous_(ensure_row_contiguous) {}
269 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
271 throw std::runtime_error(
"Custom Metal kernels only run on GPU.");
274 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
282 std::tuple<int, int, int> grid_;
283 std::tuple<int, int, int> threadgroup_;
284 std::vector<CustomKernelShapeInfo> shape_infos_;
285 bool ensure_row_contiguous_;
Definition primitives.h:48
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Definition fast_primitives.h:215
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:228
DEFINE_PRINT(AffineQuantize)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
AffineQuantize(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int group_size, int bits, bool dequantize)
Definition fast_primitives.h:217
Definition fast_primitives.h:10
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:12
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition fast_primitives.h:251
DEFINE_PRINT(CustomKernel)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:269
CustomKernel(Stream stream, std::string name, std::string source, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::vector< CustomKernelShapeInfo > shape_infos, bool ensure_row_contiguous)
Definition fast_primitives.h:253
Definition fast_primitives.h:88
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:90
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:96
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:117
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:125
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:119
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:36
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:38
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:44
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:65
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:67
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:73
Definition fast_primitives.h:140
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
Definition fast_primitives.h:142
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:159
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition fast_primitives.h:185
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:199
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
Definition fast_primitives.h:187
DEFINE_PRINT(ScaledDotProductAttention)
void eval_gpu(const std::vector< array > &inputs, array &out)
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:194
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
Definition fast_primitives.h:245
bool strides
Definition fast_primitives.h:247
bool shape
Definition fast_primitives.h:246
bool ndim
Definition fast_primitives.h:248