5#include <unordered_set>
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
79 virtual std::vector<array>
jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
87 virtual std::vector<array>
vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
99 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
104 virtual void print(std::ostream& os) = 0;
114 const std::vector<array>& inputs);
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs)
override {
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs)
override {
169 void eval(const std::vector<
array>& inputs,
array& out);
186 void eval(const std::vector<
array>& inputs,
array& out);
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs)
override;
229 void eval(const std::vector<
array>& inputs,
array& out);
246 void eval(const std::vector<
array>& inputs,
array& out);
263 void eval(const std::vector<
array>& inputs,
array& out);
280 void eval(const std::vector<
array>& inputs,
array& out);
297 void eval(const std::vector<
array>& inputs,
array& out);
314 void eval(const std::vector<
array>& inputs,
array& out);
331 void eval(const std::vector<
array>& inputs,
array& out);
348 void eval(const std::vector<
array>& inputs,
array& out);
369 void eval(const std::vector<
array>& inputs,
array& out);
390 const std::vector<
array>& inputs) override;
396 void eval(const std::vector<
array>& inputs,
array& out);
415 void eval(const std::vector<
array>& inputs,
array& out);
435 void eval(const std::vector<
array>& inputs,
array& out);
442 std::vector<int> shape,
443 std::vector<size_t> strides,
446 shape_(
std::move(shape)),
447 strides_(
std::move(strides)),
458 std::vector<
int> shape_;
459 std::vector<
size_t> strides_;
462 void eval(const std::vector<
array>& inputs,
array& out);
478 void print(std::ostream& os) override;
494 const std::vector<array>& primals,
495 const std::vector<array>& cotangents,
496 const std::vector<int>& argnums,
497 const std::vector<array>& outputs)
override;
505 void eval(const std::vector<
array>& inputs,
array& out);
516 const std::vector<array>& primals,
517 const std::vector<array>& cotangents,
518 const std::vector<int>& argnums,
519 const std::vector<array>& outputs)
override;
525 void eval(const std::vector<
array>& inputs,
array& out);
542 std::vector<
int> shape_;
544 void eval(const std::vector<
array>& inputs,
array& out);
561 void eval(const std::vector<
array>& inputs,
array& out);
577 std::vector<array> inputs,
578 std::vector<array> outputs,
579 std::vector<array> tape,
580 std::unordered_set<uintptr_t> constant_ids);
582 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
584 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
590 const std::vector<
array>& inputs) override;
591 void print(std::ostream& os) override;
594 std::
string lib_name()
const {
599 const std::vector<array> inputs_;
600 const std::vector<array> outputs_;
601 const std::vector<array> tape_;
602 const std::unordered_set<uintptr_t> constant_ids_;
604 std::string kernel_lib_;
618 bool is_equivalent(const
Primitive& other) const override;
623 void eval(const std::vector<
array>& inputs,
array& out);
639 void eval(const std::vector<
array>& inputs,
array& out);
655 bool is_equivalent(const
Primitive& other) const override;
658 bool allow_col_major_;
665 const std::vector<int>& kernel_strides,
666 const std::vector<int>& padding,
667 const std::vector<int>& kernel_dilation,
668 const std::vector<int>& input_dilation,
669 const int groups = 1,
670 const bool flip =
false)
673 kernel_strides_(kernel_strides),
674 kernel_dilation_(kernel_dilation),
675 input_dilation_(input_dilation),
683 const std::vector<array>& primals,
684 const std::vector<array>& cotangents,
685 const std::vector<int>& argnums,
686 const std::vector<array>& outputs)
override;
689 bool is_equivalent(const
Primitive& other) const override;
692 std::vector<
int> padding_;
693 std::vector<
int> kernel_strides_;
694 std::vector<
int> kernel_dilation_;
695 std::vector<
int> input_dilation_;
699 void eval(const std::vector<
array>& inputs,
array& out);
716 void eval(const std::vector<
array>& inputs,
array& out);
733 void eval(const std::vector<
array>& inputs,
array& out);
750 void eval(const std::vector<
array>& inputs,
array& out);
758 std::function<std::vector<array>(
759 const std::vector<array>&,
760 const std::vector<array>&,
761 const std::vector<array>&)> vjp,
762 std::function<std::vector<array>(
763 const std::vector<array>&,
764 const std::vector<array>&,
765 const std::vector<int>&)> jvp,
766 std::function<std::pair<std::vector<array>, std::vector<int>>(
767 const std::vector<array>&,
768 const std::vector<int>&)> vmap)
770 num_outputs_(num_outputs),
775 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
777 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
785 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
789 std::function<std::vector<array>(
790 const std::vector<array>&,
791 const std::vector<array>&,
792 const std::vector<array>&)>
794 std::function<std::vector<array>(
795 const std::vector<array>&,
796 const std::vector<array>&,
797 const std::vector<int>&)>
799 std::function<std::pair<std::vector<array>, std::vector<int>>(
800 const std::vector<array>&,
801 const std::vector<int>&)>
809 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
811 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
815 const std::vector<array>& primals,
816 const std::vector<array>& cotan,
817 const std::vector<int>& argnums,
818 const std::vector<array>& outputs)
override;
823 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
840 void eval(const std::vector<
array>& inputs,
array& out);
847 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
849 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
856 std::vector<std::vector<
int>> output_shapes(
857 const std::vector<
array>& inputs)
override {
858 return std::vector{inputs[0].shape(), inputs[0].shape()};
862 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
879 void eval(const std::vector<
array>& inputs,
array& out);
896 void eval(const std::vector<
array>& inputs,
array& out);
912 void print(std::ostream& os)
override {
921 void eval(
const std::vector<array>& inputs,
array& out);
939 void eval(const std::vector<
array>& inputs,
array& out);
956 void eval(const std::vector<
array>& inputs,
array& out);
973 void eval(const std::vector<
array>& inputs,
array& out);
989 void eval(const std::vector<
array>& inputs,
array& out);
996 const std::vector<size_t>& axes,
1011 std::vector<
size_t> axes_;
1015 void eval(const std::vector<
array>& inputs,
array& out);
1032 void eval(const std::vector<
array>& inputs,
array& out);
1048 void eval(const std::vector<
array>& inputs,
array& out);
1055 const std::vector<int>& axes,
1056 const std::vector<int>& slice_sizes)
1057 :
UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1068 void eval(const std::vector<
array>& inputs,
array& out);
1069 std::vector<
int> axes_;
1070 std::vector<
int> slice_sizes_;
1087 void eval(const std::vector<
array>& inputs,
array& out);
1104 void eval(const std::vector<
array>& inputs,
array& out);
1125 void eval(const std::vector<
array>& inputs,
array& out);
1156 void eval(const std::vector<
array>& inputs,
array& out);
1173 void eval(const std::vector<
array>& inputs,
array& out);
1180 std::shared_ptr<io::Reader> reader,
1182 bool swap_endianness =
false)
1184 reader_(
std::move(reader)),
1186 swap_endianness_(swap_endianness) {
1187 if (stream.
device == Device::gpu) {
1199 static Stream io_stream = new_stream(Device::cpu);
1202 void eval(
const std::vector<array>& inputs,
array& out);
1203 std::shared_ptr<io::Reader> reader_;
1205 bool swap_endianness_;
1223 void print(std::ostream& os)
override {
1239 void eval(
const std::vector<array>& inputs,
array& out);
1255 void eval(const std::vector<
array>& inputs,
array& out);
1272 void eval(const std::vector<
array>& inputs,
array& out);
1289 void eval(const std::vector<
array>& inputs,
array& out);
1306 void eval(const std::vector<
array>& inputs,
array& out);
1323 void eval(const std::vector<
array>& inputs,
array& out);
1334 const std::vector<array>& primals,
1335 const std::vector<array>& cotangents,
1336 const std::vector<int>& argnums,
1337 const std::vector<array>& outputs)
override;
1358 void eval(const std::vector<
array>& inputs,
array& out);
1375 void eval(const std::vector<
array>& inputs,
array& out);
1392 void eval(const std::vector<
array>& inputs,
array& out);
1409 void eval(const std::vector<
array>& inputs,
array& out);
1426 void eval(const std::vector<
array>& inputs,
array& out);
1433 std::vector<int> axes,
1437 axes_(
std::move(axes)),
1438 inverted_(inverted),
1447 std::vector<std::vector<
int>> output_shapes(
1448 const std::vector<
array>& inputs)
override {
1453 std::vector<int> axes_;
1457 void eval(
const std::vector<array>& inputs,
array& out);
1464 const std::vector<int>& axes,
1465 const std::vector<int>& low_pad_size,
1466 const std::vector<int>& high_pad_size)
1469 low_pad_size_(low_pad_size),
1470 high_pad_size_(high_pad_size) {}
1481 std::vector<
int> axes_;
1482 std::vector<
int> low_pad_size_;
1483 std::vector<
int> high_pad_size_;
1485 void eval(const std::vector<
array>& inputs,
array& out);
1506 void eval(const std::vector<
array>& inputs,
array& out);
1523 void eval(const std::vector<
array>& inputs,
array& out);
1534 group_size_(group_size),
1551 void eval(const std::vector<
array>& inputs,
array& out);
1558 group_size_(group_size),
1575 void eval(const std::vector<
array>& inputs,
array& out);
1591 std::vector<
int> shape_;
1594 void eval(const std::vector<
array>& inputs,
array& out);
1625 std::vector<
int> shape_;
1627 void eval(const std::vector<
array>& inputs,
array& out);
1629 std::pair<
bool, std::vector<
size_t>> prepare_reshape(
1632 void shared_buffer_reshape(
1634 const std::vector<
size_t>& out_strides,
1645 const std::vector<int>& axes)
1646 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1654 const std::vector<
array>& primals,
1655 const std::vector<
array>& cotangents,
1656 const std::vector<
int>& argnums,
1657 const std::vector<
array>& outputs) override;
1659 std::vector<std::vector<
int>> output_shapes(
1660 const std::vector<
array>& inputs) override;
1662 void print(std::ostream& os)
override {
1663 switch (reduce_type_) {
1688 std::vector<int> axes_;
1690 void eval(
const std::vector<array>& inputs,
array& out);
1707 void eval(const std::vector<
array>& inputs,
array& out);
1721 reduce_type_(reduce_type),
1724 inclusive_(inclusive) {}
1732 void print(std::ostream& os)
override {
1734 switch (reduce_type_) {
1757 void eval(
const std::vector<array>& inputs,
array& out);
1767 const std::vector<int>& axes)
1768 :
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1778 switch (reduce_type_) {
1798 void eval(
const std::vector<array>& inputs,
array& out);
1800 std::vector<int> axes_;
1817 void eval(const std::vector<
array>& inputs,
array& out);
1834 void eval(const std::vector<
array>& inputs,
array& out);
1851 void eval(const std::vector<
array>& inputs,
array& out);
1868 void eval(const std::vector<
array>& inputs,
array& out);
1875 const std::vector<int>& start_indices,
1876 const std::vector<int>& end_indices,
1877 const std::vector<int>& strides)
1879 start_indices_(start_indices),
1880 end_indices_(end_indices),
1881 strides_(strides) {}
1892 std::vector<
int> start_indices_;
1893 std::vector<
int> end_indices_;
1894 std::vector<
int> strides_;
1896 void eval(const std::vector<
array>& inputs,
array& out);
1903 const std::vector<int>& start_indices,
1904 const std::vector<int>& end_indices,
1905 const std::vector<int>& strides)
1907 start_indices_(start_indices),
1908 end_indices_(end_indices),
1909 strides_(strides) {}
1920 std::vector<
int> start_indices_;
1921 std::vector<
int> end_indices_;
1922 std::vector<
int> strides_;
1924 void eval(const std::vector<
array>& inputs,
array& out);
1926 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const
array& in);
1945 void eval(const std::vector<
array>& inputs,
array& out);
1966 void eval(const std::vector<
array>& inputs,
array& out);
1971 explicit Split(
Stream stream,
const std::vector<int>& indices,
int axis)
1972 :
Primitive(stream), indices_(indices), axis_(axis) {}
1974 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1976 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
1985 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
1987 std::vector<
int> indices_;
2005 void eval(const std::vector<
array>& inputs,
array& out);
2021 void print(std::ostream& os)
override {
2030 void eval(
const std::vector<array>& inputs,
array& out);
2047 void eval(const std::vector<
array>& inputs,
array& out);
2064 void eval(const std::vector<
array>& inputs,
array& out);
2081 void eval(const std::vector<
array>& inputs,
array& out);
2098 void eval(const std::vector<
array>& inputs,
array& out);
2113 void eval(const std::vector<
array>& inputs,
array& out);
2125 void print(std::ostream& os) override;
2146 std::vector<
int> axes_;
2148 void eval(const std::vector<
array>& inputs,
array& out);
2156 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2158 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2164 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2172 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2174 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2181 void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
2197 void eval(const std::vector<
array>& inputs,
array& output);
2214 void eval(const std::vector<
array>& inputs,
array& output);
2220 explicit Eigh(
Stream stream, std::string uplo,
bool compute_eigenvectors)
2222 uplo_(
std::move(uplo)),
2223 compute_eigenvectors_(compute_eigenvectors) {}
2225 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2227 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2233 std::vector<std::vector<
int>> output_shapes(
2234 const std::vector<
array>& inputs)
override {
2235 auto shape = inputs[0].shape();
2237 if (compute_eigenvectors_) {
2238 return {shape, inputs[0].shape()};
2245 if (
auto* p =
dynamic_cast<const Eigh*
>(&other)) {
2246 return uplo_ == p->uplo_ &&
2247 compute_eigenvectors_ == p->compute_eigenvectors_;
2253 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2255 bool compute_eigenvectors_;
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:401
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:438
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:440
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:418
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:420
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:465
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:469
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:467
@ And
Definition primitives.h:467
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:485
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:487
Definition primitives.h:528
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:530
Definition primitives.h:547
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:549
Definition primitives.h:2202
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2204
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:564
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:607
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:609
Definition primitives.h:626
Conjugate(Stream stream)
Definition primitives.h:628
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:642
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:644
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:661
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:663
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:702
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:704
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:719
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:721
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:736
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:738
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:805
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:807
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:843
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:845
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:826
Divide(Stream stream)
Definition primitives.h:828
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2218
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:2244
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2220
Definition primitives.h:899
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:901
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:925
Erf(Stream stream)
Definition primitives.h:927
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:942
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:944
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:959
Exp(Stream stream)
Definition primitives.h:961
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:976
Expm1(Stream stream)
Definition primitives.h:978
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:992
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:994
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1018
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1020
Definition primitives.h:1035
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1037
Definition primitives.h:1051
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1053
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:508
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:510
Definition primitives.h:1554
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1556
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1090
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1092
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1073
Greater(Stream stream)
Definition primitives.h:1075
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1107
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1109
Definition primitives.h:1128
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1130
Definition primitives.h:2185
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2187
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1159
LessEqual(Stream stream)
Definition primitives.h:1161
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1142
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1144
Definition primitives.h:1176
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1178
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1242
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1244
Definition primitives.h:1309
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1311
Definition primitives.h:1208
Base
Definition primitives.h:1210
Log(Stream stream, Base base)
Definition primitives.h:1212
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1275
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1277
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1258
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1260
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1292
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1294
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1326
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1328
Definition primitives.h:1344
Maximum(Stream stream)
Definition primitives.h:1346
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1361
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1363
Definition primitives.h:1378
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1380
Definition primitives.h:1395
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1397
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1412
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1414
Definition primitives.h:1429
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1431
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1460
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1462
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1488
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1490
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1509
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1511
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2152
QRF(Stream stream)
Definition primitives.h:2154
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1526
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1528
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1578
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1580
Definition primitives.h:1597
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1599
Definition primitives.h:1638
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1642
ReduceType
Definition primitives.h:1640
@ And
Definition primitives.h:1640
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:882
Remainder(Stream stream)
Definition primitives.h:884
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1611
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1613
Definition primitives.h:1693
Round(Stream stream)
Definition primitives.h:1695
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2168
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2170
Definition primitives.h:1710
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1712
@ Max
Definition primitives.h:1712
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1714
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1760
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1762
@ Max
Definition primitives.h:1762
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1776
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1764
Definition primitives.h:865
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:867
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1803
Sigmoid(Stream stream)
Definition primitives.h:1805
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1820
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1822
Definition primitives.h:1837
Sin(Stream stream)
Definition primitives.h:1839
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1854
Sinh(Stream stream)
Definition primitives.h:1856
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1871
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1873
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1899
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1901
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1929
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1931
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1949
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1951
Definition primitives.h:1969
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1971
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:2008
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2010
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1991
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1993
Definition primitives.h:2034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2036
Definition primitives.h:2050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2052
Definition primitives.h:2067
Tan(Stream stream)
Definition primitives.h:2069
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2084
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2086
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2132
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2134
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2116
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2118
void eval_gpu(const std::vector< array > &inputs, array &out) override
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition binary_ops.h:270
Definition binary_ops.h:277
Device device
Definition stream.h:11