5#include <unordered_set>
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
79 virtual std::vector<array>
jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
87 virtual std::vector<array>
vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
99 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
104 virtual void print(std::ostream& os) = 0;
114 const std::vector<array>& inputs);
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs)
override {
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs)
override {
169 void eval(const std::vector<
array>& inputs,
array& out);
186 void eval(const std::vector<
array>& inputs,
array& out);
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs)
override;
229 void eval(const std::vector<
array>& inputs,
array& out);
246 void eval(const std::vector<
array>& inputs,
array& out);
263 void eval(const std::vector<
array>& inputs,
array& out);
280 void eval(const std::vector<
array>& inputs,
array& out);
297 void eval(const std::vector<
array>& inputs,
array& out);
314 void eval(const std::vector<
array>& inputs,
array& out);
331 void eval(const std::vector<
array>& inputs,
array& out);
348 void eval(const std::vector<
array>& inputs,
array& out);
368 void eval(const std::vector<
array>& inputs,
array& out);
388 const std::vector<
array>& inputs) override;
394 void eval(const std::vector<
array>& inputs,
array& out);
413 void eval(const std::vector<
array>& inputs,
array& out);
433 void eval(const std::vector<
array>& inputs,
array& out);
440 std::vector<int> shape,
441 std::vector<size_t> strides,
444 shape_(
std::move(shape)),
445 strides_(
std::move(strides)),
456 std::vector<
int> shape_;
457 std::vector<
size_t> strides_;
460 void eval(const std::vector<
array>& inputs,
array& out);
475 void print(std::ostream& os) override;
491 const std::vector<array>& primals,
492 const std::vector<array>& cotangents,
493 const std::vector<int>& argnums,
494 const std::vector<array>& outputs)
override;
502 void eval(const std::vector<
array>& inputs,
array& out);
513 const std::vector<array>& primals,
514 const std::vector<array>& cotangents,
515 const std::vector<int>& argnums,
516 const std::vector<array>& outputs)
override;
522 void eval(const std::vector<
array>& inputs,
array& out);
539 std::vector<
int> shape_;
541 void eval(const std::vector<
array>& inputs,
array& out);
558 void eval(const std::vector<
array>& inputs,
array& out);
574 std::vector<array> inputs,
575 std::vector<array> outputs,
576 std::vector<array> tape,
577 std::unordered_set<uintptr_t> constant_ids);
579 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
581 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
587 const std::vector<
array>& inputs) override;
588 void print(std::ostream& os) override;
591 std::
string lib_name()
const {
596 const std::vector<array> inputs_;
597 const std::vector<array> outputs_;
598 const std::vector<array> tape_;
599 const std::unordered_set<uintptr_t> constant_ids_;
601 std::string kernel_lib_;
615 bool is_equivalent(const
Primitive& other) const override;
620 void eval(const std::vector<
array>& inputs,
array& out);
636 void eval(const std::vector<
array>& inputs,
array& out);
643 const std::vector<int>& kernel_strides,
644 const std::vector<int>& padding,
645 const std::vector<int>& kernel_dilation,
646 const std::vector<int>& input_dilation,
647 const int groups = 1,
648 const bool flip =
false)
651 kernel_strides_(kernel_strides),
652 kernel_dilation_(kernel_dilation),
653 input_dilation_(input_dilation),
661 const std::vector<array>& primals,
662 const std::vector<array>& cotangents,
663 const std::vector<int>& argnums,
664 const std::vector<array>& outputs)
override;
667 bool is_equivalent(const
Primitive& other) const override;
670 std::vector<
int> padding_;
671 std::vector<
int> kernel_strides_;
672 std::vector<
int> kernel_dilation_;
673 std::vector<
int> input_dilation_;
677 void eval(const std::vector<
array>& inputs,
array& out);
694 void eval(const std::vector<
array>& inputs,
array& out);
711 void eval(const std::vector<
array>& inputs,
array& out);
728 void eval(const std::vector<
array>& inputs,
array& out);
735 std::function<std::vector<array>(
736 const std::vector<array>&,
737 const std::vector<array>&,
738 const std::vector<array>&)> fun)
741 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
743 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
747 const std::vector<array>& primals,
748 const std::vector<array>& cotan,
749 const std::vector<int>& argnums,
750 const std::vector<array>& outputs)
override;
755 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
757 std::function<std::vector<array>(
758 const std::vector<array>&,
759 const std::vector<array>&,
760 const std::vector<array>&)>
768 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
770 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
774 const std::vector<array>& primals,
775 const std::vector<array>& cotan,
776 const std::vector<int>& argnums,
777 const std::vector<array>& outputs)
override;
782 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
799 void eval(const std::vector<
array>& inputs,
array& out);
806 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
808 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
815 std::vector<std::vector<
int>> output_shapes(
816 const std::vector<
array>& inputs)
override {
817 return std::vector{inputs[0].shape(), inputs[0].shape()};
821 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
838 void eval(const std::vector<
array>& inputs,
array& out);
855 void eval(const std::vector<
array>& inputs,
array& out);
871 void print(std::ostream& os)
override {
880 void eval(
const std::vector<array>& inputs,
array& out);
898 void eval(const std::vector<
array>& inputs,
array& out);
915 void eval(const std::vector<
array>& inputs,
array& out);
932 void eval(const std::vector<
array>& inputs,
array& out);
948 void eval(const std::vector<
array>& inputs,
array& out);
955 const std::vector<size_t>& axes,
958 :
UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {};
967 bool is_equivalent(const
Primitive& other) const override;
970 std::vector<
size_t> axes_;
974 void eval(const std::vector<
array>& inputs,
array& out);
991 void eval(const std::vector<
array>& inputs,
array& out);
1007 void eval(const std::vector<
array>& inputs,
array& out);
1014 const std::vector<int>& axes,
1015 const std::vector<int>& slice_sizes)
1016 :
UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {};
1027 void eval(const std::vector<
array>& inputs,
array& out);
1028 std::vector<
int> axes_;
1029 std::vector<
int> slice_sizes_;
1046 void eval(const std::vector<
array>& inputs,
array& out);
1063 void eval(const std::vector<
array>& inputs,
array& out);
1080 void eval(const std::vector<
array>& inputs,
array& out);
1097 void eval(const std::vector<
array>& inputs,
array& out);
1104 std::shared_ptr<io::Reader> reader,
1106 bool swap_endianness =
false)
1110 swap_endianness_(swap_endianness) {};
1118 void eval(
const std::vector<array>& inputs,
array& out);
1119 std::shared_ptr<io::Reader> reader_;
1121 bool swap_endianness_;
1139 void print(std::ostream& os)
override {
1155 void eval(
const std::vector<array>& inputs,
array& out);
1171 void eval(const std::vector<
array>& inputs,
array& out);
1188 void eval(const std::vector<
array>& inputs,
array& out);
1205 void eval(const std::vector<
array>& inputs,
array& out);
1222 void eval(const std::vector<
array>& inputs,
array& out);
1239 void eval(const std::vector<
array>& inputs,
array& out);
1250 const std::vector<array>& primals,
1251 const std::vector<array>& cotangents,
1252 const std::vector<int>& argnums,
1253 const std::vector<array>& outputs)
override;
1274 void eval(const std::vector<
array>& inputs,
array& out);
1291 void eval(const std::vector<
array>& inputs,
array& out);
1308 void eval(const std::vector<
array>& inputs,
array& out);
1325 void eval(const std::vector<
array>& inputs,
array& out);
1342 void eval(const std::vector<
array>& inputs,
array& out);
1349 std::vector<int> axes,
1353 axes_(
std::move(axes)),
1354 inverted_(inverted),
1363 std::vector<std::vector<
int>> output_shapes(
1364 const std::vector<
array>& inputs)
override {
1369 std::vector<int> axes_;
1373 void eval(
const std::vector<array>& inputs,
array& out);
1380 const std::vector<int>& axes,
1381 const std::vector<int>& low_pad_size,
1382 const std::vector<int>& high_pad_size)
1385 low_pad_size_(low_pad_size),
1386 high_pad_size_(high_pad_size) {};
1397 std::vector<
int> axes_;
1398 std::vector<
int> low_pad_size_;
1399 std::vector<
int> high_pad_size_;
1401 void eval(const std::vector<
array>& inputs,
array& out);
1422 void eval(const std::vector<
array>& inputs,
array& out);
1439 void eval(const std::vector<
array>& inputs,
array& out);
1450 group_size_(group_size),
1467 void eval(const std::vector<
array>& inputs,
array& out);
1474 group_size_(group_size),
1491 void eval(const std::vector<
array>& inputs,
array& out);
1507 std::vector<
int> shape_;
1510 void eval(const std::vector<
array>& inputs,
array& out);
1527 std::vector<
int> shape_;
1529 void eval(const std::vector<
array>& inputs,
array& out);
1531 std::pair<
bool, std::vector<
size_t>> prepare_reshape(
1534 void shared_buffer_reshape(
1536 const std::vector<
size_t>& out_strides,
1547 const std::vector<int>& axes)
1548 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
1556 const std::vector<
array>& primals,
1557 const std::vector<
array>& cotangents,
1558 const std::vector<
int>& argnums,
1559 const std::vector<
array>& outputs) override;
1561 std::vector<std::vector<
int>> output_shapes(
1562 const std::vector<
array>& inputs) override;
1564 void print(std::ostream& os)
override {
1565 switch (reduce_type_) {
1590 std::vector<int> axes_;
1592 void eval(
const std::vector<array>& inputs,
array& out);
1609 void eval(const std::vector<
array>& inputs,
array& out);
1623 reduce_type_(reduce_type),
1626 inclusive_(inclusive) {};
1634 void print(std::ostream& os)
override {
1636 switch (reduce_type_) {
1659 void eval(
const std::vector<array>& inputs,
array& out);
1669 const std::vector<int>& axes)
1670 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
1678 switch (reduce_type_) {
1698 void eval(
const std::vector<array>& inputs,
array& out);
1700 std::vector<int> axes_;
1717 void eval(const std::vector<
array>& inputs,
array& out);
1734 void eval(const std::vector<
array>& inputs,
array& out);
1751 void eval(const std::vector<
array>& inputs,
array& out);
1768 void eval(const std::vector<
array>& inputs,
array& out);
1775 const std::vector<int>& start_indices,
1776 const std::vector<int>& end_indices,
1777 const std::vector<int>& strides)
1779 start_indices_(start_indices),
1780 end_indices_(end_indices),
1781 strides_(strides) {};
1792 std::vector<
int> start_indices_;
1793 std::vector<
int> end_indices_;
1794 std::vector<
int> strides_;
1796 void eval(const std::vector<
array>& inputs,
array& out);
1803 const std::vector<int>& start_indices,
1804 const std::vector<int>& end_indices,
1805 const std::vector<int>& strides)
1807 start_indices_(start_indices),
1808 end_indices_(end_indices),
1809 strides_(strides) {};
1820 std::vector<
int> start_indices_;
1821 std::vector<
int> end_indices_;
1822 std::vector<
int> strides_;
1824 void eval(const std::vector<
array>& inputs,
array& out);
1826 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const
array& in);
1845 void eval(const std::vector<
array>& inputs,
array& out);
1866 void eval(const std::vector<
array>& inputs,
array& out);
1871 explicit Split(
Stream stream,
const std::vector<int>& indices,
int axis)
1872 :
Primitive(stream), indices_(indices), axis_(axis) {};
1874 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1876 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1885 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
1887 std::vector<
int> indices_;
1905 void eval(const std::vector<
array>& inputs,
array& out);
1921 void print(std::ostream& os)
override {
1930 void eval(
const std::vector<array>& inputs,
array& out);
1947 void eval(const std::vector<
array>& inputs,
array& out);
1964 void eval(const std::vector<
array>& inputs,
array& out);
1981 void eval(const std::vector<
array>& inputs,
array& out);
1998 void eval(const std::vector<
array>& inputs,
array& out);
2013 void eval(const std::vector<
array>& inputs,
array& out);
2025 void print(std::ostream& os) override;
2046 std::vector<
int> axes_;
2048 void eval(const std::vector<
array>& inputs,
array& out);
2056 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2058 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2064 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2072 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2074 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2081 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
2096 void eval(const std::vector<
array>& inputs,
array& output);
2111 void eval(const std::vector<
array>& inputs,
array& output);
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:371
ReduceType
Definition primitives.h:373
@ ArgMin
Definition primitives.h:374
@ ArgMax
Definition primitives.h:375
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:378
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:397
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:399
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:436
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:438
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:416
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:418
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:463
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:467
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:465
@ And
Definition primitives.h:465
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:482
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:484
Definition primitives.h:525
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:527
Definition primitives.h:544
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:546
Definition primitives.h:2099
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2101
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:561
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:604
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:606
Definition primitives.h:623
Conjugate(Stream stream)
Definition primitives.h:625
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:639
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:641
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:680
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:682
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:697
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:699
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:714
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:716
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:731
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
CustomVJP(Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)
Definition primitives.h:733
Definition primitives.h:764
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:766
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:802
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:804
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:785
Divide(Stream stream)
Definition primitives.h:787
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:858
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:860
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:884
Erf(Stream stream)
Definition primitives.h:886
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:901
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:903
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:918
Exp(Stream stream)
Definition primitives.h:920
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:935
Expm1(Stream stream)
Definition primitives.h:937
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:951
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:953
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:977
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:979
Definition primitives.h:994
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:996
Definition primitives.h:1010
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1012
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:505
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:507
Definition primitives.h:1470
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1472
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1049
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1051
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1032
Greater(Stream stream)
Definition primitives.h:1034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2085
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream)
Definition primitives.h:2087
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1083
LessEqual(Stream stream)
Definition primitives.h:1085
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1066
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1068
Definition primitives.h:1100
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1102
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1158
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1160
Definition primitives.h:1225
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1227
Definition primitives.h:1124
Base
Definition primitives.h:1126
Log(Stream stream, Base base)
Definition primitives.h:1128
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1191
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1193
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1174
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1176
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1208
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1210
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1242
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1244
Definition primitives.h:1260
Maximum(Stream stream)
Definition primitives.h:1262
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1277
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1279
Definition primitives.h:1294
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1296
Definition primitives.h:1311
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1313
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1328
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1330
Definition primitives.h:1345
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1347
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1376
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1378
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1404
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1406
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1425
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1427
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2052
QRF(Stream stream)
Definition primitives.h:2054
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1442
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1444
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1494
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1496
Definition primitives.h:1540
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1544
ReduceType
Definition primitives.h:1542
@ And
Definition primitives.h:1542
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:841
Remainder(Stream stream)
Definition primitives.h:843
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1513
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1515
Definition primitives.h:1595
Round(Stream stream)
Definition primitives.h:1597
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2068
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2070
Definition primitives.h:1612
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1614
@ Max
Definition primitives.h:1614
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1616
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1662
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1664
@ Max
Definition primitives.h:1664
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1676
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1666
Definition primitives.h:824
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:826
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1703
Sigmoid(Stream stream)
Definition primitives.h:1705
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1720
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1722
Definition primitives.h:1737
Sin(Stream stream)
Definition primitives.h:1739
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1754
Sinh(Stream stream)
Definition primitives.h:1756
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1771
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1773
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1799
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1801
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1829
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1831
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1849
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1851
Definition primitives.h:1869
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1871
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:1908
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1910
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1891
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1893
Definition primitives.h:1934
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:1936
Definition primitives.h:1950
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:1952
Definition primitives.h:1967
Tan(Stream stream)
Definition primitives.h:1969
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1984
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:1986
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2032
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2016
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2018
void eval_gpu(const std::vector< array > &inputs, array &out) override
Op op
Definition binary.h:141
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition binary_ops.h:270
Definition binary_ops.h:277
Device device
Definition stream.h:11