5#include <unordered_set>
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
42 return {inputs[0].shape()}; \
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
79 virtual std::vector<array>
jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
87 virtual std::vector<array>
vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
99 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
104 virtual void print(std::ostream& os) = 0;
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs)
override {
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs)
override {
168 void eval(const std::vector<
array>& inputs,
array& out);
185 void eval(const std::vector<
array>& inputs,
array& out);
197 const std::vector<array>& primals,
198 const std::vector<array>& cotangents,
199 const std::vector<int>& argnums,
200 const std::vector<array>& outputs)
override;
229 void eval(const std::vector<
array>& inputs,
array& out);
246 void eval(const std::vector<
array>& inputs,
array& out);
263 void eval(const std::vector<
array>& inputs,
array& out);
280 void eval(const std::vector<
array>& inputs,
array& out);
297 void eval(const std::vector<
array>& inputs,
array& out);
314 void eval(const std::vector<
array>& inputs,
array& out);
331 void eval(const std::vector<
array>& inputs,
array& out);
348 void eval(const std::vector<
array>& inputs,
array& out);
369 void eval(const std::vector<
array>& inputs,
array& out);
395 void eval(const std::vector<
array>& inputs,
array& out);
414 void eval(const std::vector<
array>& inputs,
array& out);
434 void eval(const std::vector<
array>& inputs,
array& out);
441 shape_(
std::move(shape)),
442 strides_(
std::move(strides)),
457 void eval(const std::vector<
array>& inputs,
array& out);
473 void print(std::ostream& os) override;
489 const std::vector<array>& primals,
490 const std::vector<array>& cotangents,
491 const std::vector<int>& argnums,
492 const std::vector<array>& outputs)
override;
500 void eval(const std::vector<
array>& inputs,
array& out);
511 const std::vector<array>& primals,
512 const std::vector<array>& cotangents,
513 const std::vector<int>& argnums,
514 const std::vector<array>& outputs)
override;
520 void eval(const std::vector<
array>& inputs,
array& out);
539 void eval(const std::vector<
array>& inputs,
array& out);
556 void eval(const std::vector<
array>& inputs,
array& out);
572 std::vector<array> inputs,
573 std::vector<array> outputs,
574 std::vector<array> tape,
575 std::unordered_set<uintptr_t> constant_ids);
577 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
579 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
585 void print(std::ostream& os) override;
588 std::
string lib_name()
const {
593 const std::vector<array> inputs_;
594 const std::vector<array> outputs_;
595 const std::vector<array> tape_;
596 const std::unordered_set<uintptr_t> constant_ids_;
598 std::string kernel_lib_;
612 bool is_equivalent(const
Primitive& other) const override;
613 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
618 void eval(const std::vector<
array>& inputs,
array& out);
634 void eval(const std::vector<
array>& inputs,
array& out);
650 bool is_equivalent(const
Primitive& other) const override;
653 bool allow_col_major_;
660 const std::vector<int>& kernel_strides,
661 const std::vector<int>& padding,
662 const std::vector<int>& kernel_dilation,
663 const std::vector<int>& input_dilation,
664 const int groups = 1,
665 const bool flip =
false)
668 kernel_strides_(kernel_strides),
669 kernel_dilation_(kernel_dilation),
670 input_dilation_(input_dilation),
678 const std::vector<array>& primals,
679 const std::vector<array>& cotangents,
680 const std::vector<int>& argnums,
681 const std::vector<array>& outputs)
override;
684 bool is_equivalent(const
Primitive& other) const override;
687 std::vector<
int> padding_;
688 std::vector<
int> kernel_strides_;
689 std::vector<
int> kernel_dilation_;
690 std::vector<
int> input_dilation_;
694 void eval(const std::vector<
array>& inputs,
array& out);
711 void eval(const std::vector<
array>& inputs,
array& out);
728 void eval(const std::vector<
array>& inputs,
array& out);
745 void eval(const std::vector<
array>& inputs,
array& out);
753 std::function<std::vector<array>(
754 const std::vector<array>&,
755 const std::vector<array>&,
756 const std::vector<array>&)> vjp,
757 std::function<std::vector<array>(
758 const std::vector<array>&,
759 const std::vector<array>&,
760 const std::vector<int>&)> jvp,
761 std::function<std::pair<std::vector<array>, std::vector<int>>(
762 const std::vector<array>&,
763 const std::vector<int>&)> vmap)
765 num_outputs_(num_outputs),
770 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
772 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
780 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
784 std::function<std::vector<array>(
785 const std::vector<array>&,
786 const std::vector<array>&,
787 const std::vector<array>&)>
789 std::function<std::vector<array>(
790 const std::vector<array>&,
791 const std::vector<array>&,
792 const std::vector<int>&)>
794 std::function<std::pair<std::vector<array>, std::vector<int>>(
795 const std::vector<array>&,
796 const std::vector<int>&)>
804 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
806 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
810 const std::vector<array>& primals,
811 const std::vector<array>& cotan,
812 const std::vector<int>& argnums,
813 const std::vector<array>& outputs)
override;
818 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
835 void eval(const std::vector<
array>& inputs,
array& out);
842 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
844 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
851 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs)
override {
852 return std::vector{inputs[0].shape(), inputs[0].shape()};
856 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
873 void eval(const std::vector<
array>& inputs,
array& out);
890 void eval(const std::vector<
array>& inputs,
array& out);
906 void print(std::ostream& os)
override {
915 void eval(
const std::vector<array>& inputs,
array& out);
933 void eval(const std::vector<
array>& inputs,
array& out);
950 void eval(const std::vector<
array>& inputs,
array& out);
967 void eval(const std::vector<
array>& inputs,
array& out);
983 void eval(const std::vector<
array>& inputs,
array& out);
990 const std::vector<size_t>& axes,
1005 std::vector<
size_t> axes_;
1009 void eval(const std::vector<
array>& inputs,
array& out);
1026 void eval(const std::vector<
array>& inputs,
array& out);
1042 void eval(const std::vector<
array>& inputs,
array& out);
1049 const std::vector<int>& axes,
1050 const std::vector<int>& slice_sizes)
1051 :
UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1060 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
1063 void eval(const std::vector<
array>& inputs,
array& out);
1064 std::vector<
int> axes_;
1065 std::vector<
int> slice_sizes_;
1082 void eval(const std::vector<
array>& inputs,
array& out);
1099 void eval(const std::vector<
array>& inputs,
array& out);
1120 void eval(const std::vector<
array>& inputs,
array& out);
1151 void eval(const std::vector<
array>& inputs,
array& out);
1168 void eval(const std::vector<
array>& inputs,
array& out);
1175 std::shared_ptr<io::Reader> reader,
1177 bool swap_endianness =
false)
1179 reader_(
std::move(reader)),
1181 swap_endianness_(swap_endianness) {
1182 if (stream.
device == Device::gpu) {
1194 static Stream io_stream = new_stream(Device::cpu);
1197 void eval(
const std::vector<array>& inputs,
array& out);
1198 std::shared_ptr<io::Reader> reader_;
1200 bool swap_endianness_;
1218 void print(std::ostream& os)
override {
1234 void eval(
const std::vector<array>& inputs,
array& out);
1250 void eval(const std::vector<
array>& inputs,
array& out);
1267 void eval(const std::vector<
array>& inputs,
array& out);
1284 void eval(const std::vector<
array>& inputs,
array& out);
1301 void eval(const std::vector<
array>& inputs,
array& out);
1318 void eval(const std::vector<
array>& inputs,
array& out);
1329 const std::vector<array>& primals,
1330 const std::vector<array>& cotangents,
1331 const std::vector<int>& argnums,
1332 const std::vector<array>& outputs)
override;
1337 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
1354 void eval(const std::vector<
array>& inputs,
array& out);
1371 void eval(const std::vector<
array>& inputs,
array& out);
1388 void eval(const std::vector<
array>& inputs,
array& out);
1405 void eval(const std::vector<
array>& inputs,
array& out);
1422 void eval(const std::vector<
array>& inputs,
array& out);
1429 std::vector<int> axes,
1433 axes_(
std::move(axes)),
1434 inverted_(inverted),
1443 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs)
override {
1448 std::vector<int> axes_;
1452 void eval(
const std::vector<array>& inputs,
array& out);
1459 const std::vector<int>& axes,
1460 const std::vector<int>& low_pad_size,
1461 const std::vector<int>& high_pad_size)
1464 low_pad_size_(low_pad_size),
1465 high_pad_size_(high_pad_size) {}
1476 std::vector<
int> axes_;
1477 std::vector<
int> low_pad_size_;
1478 std::vector<
int> high_pad_size_;
1480 void eval(const std::vector<
array>& inputs,
array& out);
1501 void eval(const std::vector<
array>& inputs,
array& out);
1518 void eval(const std::vector<
array>& inputs,
array& out);
1529 group_size_(group_size),
1540 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
1547 void eval(const std::vector<
array>& inputs,
array& out);
1554 group_size_(group_size),
1571 void eval(const std::vector<
array>& inputs,
array& out);
1590 void eval(const std::vector<
array>& inputs,
array& out);
1623 void eval(const std::vector<
array>& inputs,
array& out);
1625 static std::pair<
bool,
Strides> prepare_reshape(
1628 static
void shared_buffer_reshape(
1641 const std::vector<int>& axes)
1642 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1650 const std::vector<
array>& primals,
1651 const std::vector<
array>& cotangents,
1652 const std::vector<
int>& argnums,
1653 const std::vector<
array>& outputs) override;
1655 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
1657 void print(std::ostream& os)
override {
1658 switch (reduce_type_) {
1683 std::vector<int> axes_;
1685 void eval(
const std::vector<array>& inputs,
array& out);
1702 void eval(const std::vector<
array>& inputs,
array& out);
1716 reduce_type_(reduce_type),
1719 inclusive_(inclusive) {}
1727 void print(std::ostream& os)
override {
1729 switch (reduce_type_) {
1752 void eval(
const std::vector<array>& inputs,
array& out);
1762 const std::vector<int>& axes)
1763 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1773 switch (reduce_type_) {
1793 void eval(
const std::vector<array>& inputs,
array& out);
1795 std::vector<int> axes_;
1812 void eval(const std::vector<
array>& inputs,
array& out);
1829 void eval(const std::vector<
array>& inputs,
array& out);
1846 void eval(const std::vector<
array>& inputs,
array& out);
1863 void eval(const std::vector<
array>& inputs,
array& out);
1870 const std::vector<int>& start_indices,
1871 const std::vector<int>& end_indices,
1872 const std::vector<int>& strides)
1874 start_indices_(start_indices),
1875 end_indices_(end_indices),
1876 strides_(strides) {}
1887 std::vector<
int> start_indices_;
1888 std::vector<
int> end_indices_;
1889 std::vector<
int> strides_;
1891 void eval(const std::vector<
array>& inputs,
array& out);
1898 const std::vector<int>& start_indices,
1899 const std::vector<int>& end_indices,
1900 const std::vector<int>& strides)
1902 start_indices_(start_indices),
1903 end_indices_(end_indices),
1904 strides_(strides) {}
1915 std::vector<
int> start_indices_;
1916 std::vector<
int> end_indices_;
1917 std::vector<
int> strides_;
1919 void eval(const std::vector<
array>& inputs,
array& out);
1921 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const
array& in);
1940 void eval(const std::vector<
array>& inputs,
array& out);
1961 void eval(const std::vector<
array>& inputs,
array& out);
1966 explicit Split(
Stream stream,
const std::vector<int>& indices,
int axis)
1967 :
Primitive(stream), indices_(indices), axis_(axis) {}
1969 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1971 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1980 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
1982 std::vector<
int> indices_;
2000 void eval(const std::vector<
array>& inputs,
array& out);
2016 void print(std::ostream& os)
override {
2025 void eval(
const std::vector<array>& inputs,
array& out);
2042 void eval(const std::vector<
array>& inputs,
array& out);
2059 void eval(const std::vector<
array>& inputs,
array& out);
2076 void eval(const std::vector<
array>& inputs,
array& out);
2093 void eval(const std::vector<
array>& inputs,
array& out);
2108 void eval(const std::vector<
array>& inputs,
array& out);
2120 void print(std::ostream& os) override;
2139 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
2142 std::vector<
int> axes_;
2144 void eval(const std::vector<
array>& inputs,
array& out);
2152 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2154 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2160 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2168 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2170 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2177 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
2193 void eval(const std::vector<
array>& inputs,
array& output);
2210 void eval(const std::vector<
array>& inputs,
array& output);
2216 explicit Eigh(
Stream stream, std::string uplo,
bool compute_eigenvectors)
2218 uplo_(
std::move(uplo)),
2219 compute_eigenvectors_(compute_eigenvectors) {}
2221 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2223 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2229 std::vector<
Shape> output_shapes(const std::vector<
array>& inputs) override;
2234 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
2236 bool compute_eigenvectors_;
Definition primitives.h:154
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:163
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:164
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:165
Definition primitives.h:171
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:173
Definition primitives.h:188
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:190
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:212
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:214
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:398
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:400
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:437
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:439
Definition primitives.h:417
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:419
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:460
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:464
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:462
@ And
Definition primitives.h:462
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:480
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:482
Definition primitives.h:523
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:525
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:542
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:544
Definition primitives.h:2198
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2200
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:559
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:601
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:603
Definition primitives.h:621
Conjugate(Stream stream)
Definition primitives.h:623
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:637
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:639
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:656
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:658
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:697
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:699
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:714
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:716
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:731
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:733
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:800
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:802
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:838
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:840
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:821
Divide(Stream stream)
Definition primitives.h:823
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2214
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2216
Definition primitives.h:893
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:895
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:919
Erf(Stream stream)
Definition primitives.h:921
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:936
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:938
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:953
Exp(Stream stream)
Definition primitives.h:955
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:970
Expm1(Stream stream)
Definition primitives.h:972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:986
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:988
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1012
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1014
Definition primitives.h:1029
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1031
Definition primitives.h:1045
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1047
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:503
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:505
Definition primitives.h:1550
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1552
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1085
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1087
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1068
Greater(Stream stream)
Definition primitives.h:1070
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1102
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1104
Definition primitives.h:1123
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1125
Definition primitives.h:2181
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2183
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1154
LessEqual(Stream stream)
Definition primitives.h:1156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1137
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1139
Definition primitives.h:1171
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1173
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1237
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1239
Definition primitives.h:1304
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1306
Definition primitives.h:1203
Base
Definition primitives.h:1205
Log(Stream stream, Base base)
Definition primitives.h:1207
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1270
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1272
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1253
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1255
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1287
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1289
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1321
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1323
Definition primitives.h:1340
Maximum(Stream stream)
Definition primitives.h:1342
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1357
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1359
Definition primitives.h:1374
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1376
Definition primitives.h:1391
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1393
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1408
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1410
Definition primitives.h:1425
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1427
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1455
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1457
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1483
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1485
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1504
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1506
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2148
QRF(Stream stream)
Definition primitives.h:2150
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1521
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1523
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1574
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1576
Definition primitives.h:1593
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1595
Definition primitives.h:1634
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1638
ReduceType
Definition primitives.h:1636
@ And
Definition primitives.h:1636
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:876
Remainder(Stream stream)
Definition primitives.h:878
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1607
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1609
Definition primitives.h:1688
Round(Stream stream)
Definition primitives.h:1690
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2164
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2166
Definition primitives.h:1705
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1707
@ Max
Definition primitives.h:1707
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1709
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1755
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1757
@ Max
Definition primitives.h:1757
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1771
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1759
Definition primitives.h:859
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:861
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1798
Sigmoid(Stream stream)
Definition primitives.h:1800
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1815
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1817
Definition primitives.h:1832
Sin(Stream stream)
Definition primitives.h:1834
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1849
Sinh(Stream stream)
Definition primitives.h:1851
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1866
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1868
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1894
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1896
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1924
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1926
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1944
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1946
Definition primitives.h:1964
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1966
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:2003
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2005
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1986
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1988
Definition primitives.h:2029
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2031
Definition primitives.h:2045
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2047
Definition primitives.h:2062
Tan(Stream stream)
Definition primitives.h:2064
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2079
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2081
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2127
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2129
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2111
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2113
void eval_gpu(const std::vector< array > &inputs, array &out) override
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::vector< int32_t > Shape
Definition array.h:20
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
std::vector< size_t > Strides
Definition array.h:21
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition binary_ops.h:270
Definition binary_ops.h:277
Device device
Definition stream.h:11