5#include <unordered_set>
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
42 return {inputs[0].shape()}; \
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
79 virtual std::vector<array>
jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
87 virtual std::vector<array>
vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
99 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
104 virtual void print(std::ostream& os) = 0;
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs)
override {
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs)
override {
197 const std::vector<array>& primals,
198 const std::vector<array>& cotangents,
199 const std::vector<int>& argnums,
200 const std::vector<array>& outputs)
override;
207 return {alpha_, beta_};
227 return {start_, stop_, step_};
235 void eval(
const std::vector<array>& inputs,
array& out);
371 return {kth_, axis_};
378 void eval(
const std::vector<array>& inputs,
array& out);
400 return {reduce_type_, axis_};
404 ReduceType reduce_type_;
407 void eval(
const std::vector<array>& inputs,
array& out);
429 void eval(
const std::vector<array>& inputs,
array& out);
452 void eval(
const std::vector<array>& inputs,
array& out);
459 shape_(
std::move(shape)),
460 strides_(
std::move(strides)),
470 return std::make_tuple(shape_, strides_, offset_);
478 void eval(
const std::vector<array>& inputs,
array& out);
513 const std::vector<array>& primals,
514 const std::vector<array>& cotangents,
515 const std::vector<int>& argnums,
516 const std::vector<array>& outputs)
override;
527 void eval(
const std::vector<array>& inputs,
array& out);
538 const std::vector<array>& primals,
539 const std::vector<array>& cotangents,
540 const std::vector<int>& argnums,
541 const std::vector<array>& outputs)
override;
564 const
std::vector<
int>& ignore_axes);
571 void eval(
const std::vector<array>& inputs,
array& out);
572 std::vector<int> ignore_axes_;
596 void eval(
const std::vector<array>& inputs,
array& out);
629 std::vector<array> inputs,
630 std::vector<array> outputs,
631 std::vector<array> tape,
632 std::unordered_set<uintptr_t> constant_ids);
634 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
636 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
650 const std::vector<array> inputs_;
651 const std::vector<array> outputs_;
652 const std::vector<array> tape_;
653 const std::unordered_set<uintptr_t> constant_ids_;
655 std::string kernel_lib_;
678 void eval(
const std::vector<array>& inputs,
array& out);
713 bool allow_col_major_;
720 const std::vector<int>& kernel_strides,
721 const std::vector<int>& padding,
722 const std::vector<int>& kernel_dilation,
723 const std::vector<int>& input_dilation,
724 const int groups = 1,
725 const bool flip =
false)
728 kernel_strides_(kernel_strides),
729 kernel_dilation_(kernel_dilation),
730 input_dilation_(input_dilation),
738 const std::vector<array>& primals,
739 const std::vector<array>& cotangents,
740 const std::vector<int>& argnums,
741 const std::vector<array>& outputs)
override;
746 return std::make_tuple(
756 std::vector<int> padding_;
757 std::vector<int> kernel_strides_;
758 std::vector<int> kernel_dilation_;
759 std::vector<int> input_dilation_;
763 void eval(
const std::vector<array>& inputs,
array& out);
822 std::function<std::vector<array>(
823 const std::vector<array>&,
824 const std::vector<array>&,
825 const std::vector<array>&)>
vjp,
826 std::function<std::vector<array>(
827 const std::vector<array>&,
828 const std::vector<array>&,
829 const std::vector<int>&)>
jvp,
830 std::function<std::pair<std::vector<array>, std::vector<int>>(
831 const std::vector<array>&,
832 const std::vector<int>&)>
vmap)
834 num_outputs_(num_outputs),
839 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
841 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
849 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
853 std::function<std::vector<array>(
854 const std::vector<array>&,
855 const std::vector<array>&,
856 const std::vector<array>&)>
858 std::function<std::vector<array>(
859 const std::vector<array>&,
860 const std::vector<array>&,
861 const std::vector<int>&)>
863 std::function<std::pair<std::vector<array>, std::vector<int>>(
864 const std::vector<array>&,
865 const std::vector<int>&)>
873 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
875 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
879 const std::vector<array>& primals,
880 const std::vector<array>& cotan,
881 const std::vector<int>& argnums,
882 const std::vector<array>& outputs)
override;
887 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
911 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
913 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
921 return std::vector{inputs[0].shape(), inputs[0].shape()};
925 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
987 void eval(
const std::vector<array>& inputs,
array& out);
1079 void eval(
const std::vector<array>& inputs,
array& out);
1080 std::vector<int> axes_;
1087 const std::vector<size_t>& axes,
1101 return std::make_tuple(axes_, inverse_, real_);
1105 std::vector<size_t> axes_;
1109 void eval(
const std::vector<array>& inputs,
array& out);
1128 return std::make_pair(start_axis_, end_axis_);
1134 void eval(
const std::vector<array>& inputs,
array& out);
1174 axes_(
std::move(axes)),
1175 slice_sizes_(
std::move(slice_sizes)) {}
1186 return {axes_, slice_sizes_};
1190 void eval(
const std::vector<array>& inputs,
array& out);
1191 std::vector<int> axes_;
1250 void eval(
const std::vector<array>& inputs,
array& out);
1305 std::shared_ptr<io::Reader> reader,
1307 bool swap_endianness =
false)
1309 reader_(
std::move(reader)),
1311 swap_endianness_(swap_endianness) {
1327 void eval(
const std::vector<array>& inputs,
array& out);
1328 std::shared_ptr<io::Reader> reader_;
1330 bool swap_endianness_;
1368 void eval(
const std::vector<array>& inputs,
array& out);
1463 const std::vector<array>& primals,
1464 const std::vector<array>& cotangents,
1465 const std::vector<int>& argnums,
1466 const std::vector<array>& outputs)
override;
1563 std::vector<int> axes,
1567 axes_(
std::move(axes)),
1568 inverted_(inverted),
1581 return {axes_, inverted_, dtype_};
1585 std::vector<int> axes_;
1589 void eval(
const std::vector<array>& inputs,
array& out);
1596 const std::vector<int>& axes,
1597 const Shape& low_pad_size,
1598 const Shape& high_pad_size)
1601 low_pad_size_(low_pad_size),
1602 high_pad_size_(high_pad_size) {}
1612 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1616 std::vector<int> axes_;
1617 Shape low_pad_size_;
1618 Shape high_pad_size_;
1620 void eval(
const std::vector<array>& inputs,
array& out);
1637 return std::make_pair(kth_, axis_);
1644 void eval(
const std::vector<array>& inputs,
array& out);
1672 group_size_(group_size),
1685 return std::make_tuple(group_size_, bits_, transpose_);
1693 void eval(
const std::vector<array>& inputs,
array& out);
1700 group_size_(group_size),
1712 return std::make_tuple(group_size_, bits_, transpose_);
1720 void eval(
const std::vector<array>& inputs,
array& out);
1735 return {shape_, width_};
1742 void eval(
const std::vector<array>& inputs,
array& out);
1788 const std::vector<int>& axes)
1798 const
std::vector<
array>& cotangents,
1799 const
std::vector<
int>& argnums,
1800 const
std::vector<
array>& outputs) override;
1805 switch (reduce_type_) {
1827 std::pair<ReduceType, std::vector<int>>
state()
const {
1828 return {reduce_type_, axes_};
1832 ReduceType reduce_type_;
1833 std::vector<int> axes_;
1835 void eval(
const std::vector<array>& inputs,
array& out);
1866 reduce_type_(reduce_type),
1869 inclusive_(inclusive) {}
1879 switch (reduce_type_) {
1896 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1900 ReduceType reduce_type_;
1905 void eval(
const std::vector<array>& inputs,
array& out);
1915 const std::vector<int>& axes)
1926 switch (reduce_type_) {
1944 std::pair<ReduceType, std::vector<int>>
state()
const {
1945 return {reduce_type_, axes_};
1949 void eval(
const std::vector<array>& inputs,
array& out);
1950 ReduceType reduce_type_;
1951 std::vector<int> axes_;
2026 const Shape& start_indices,
2027 const Shape& end_indices,
2028 const Shape& strides)
2030 start_indices_(start_indices),
2031 end_indices_(end_indices),
2032 strides_(strides) {}
2042 return std::make_tuple(start_indices_, end_indices_, strides_);
2046 Shape start_indices_;
2050 void eval(
const std::vector<array>& inputs,
array& out);
2057 const Shape& start_indices,
2058 const Shape& end_indices,
2059 const Shape& strides)
2061 start_indices_(start_indices),
2062 end_indices_(end_indices),
2063 strides_(strides) {}
2074 return std::make_tuple(start_indices_, end_indices_, strides_);
2078 Shape start_indices_;
2082 void eval(
const std::vector<array>& inputs,
array& out);
2089 axes_(
std::move(axes)),
2090 slice_size_(
std::move(slice_size)) {}
2101 return std::make_pair(axes_, slice_size_);
2105 std::vector<int> axes_;
2127 std::vector<int> axes_;
2149 void eval(
const std::vector<array>& inputs,
array& out);
2173 void eval(
const std::vector<array>& inputs,
array& out);
2181 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2183 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2191 return {indices_, axis_};
2195 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2243 void eval(
const std::vector<array>& inputs,
array& out);
2301 void eval(
const std::vector<array>& inputs,
array& out);
2302 std::vector<int> axes_;
2355 return std::make_pair(axis_, shape_);
2361 void eval(
const std::vector<array>& inputs,
array& out);
2401 std::vector<int> axes_;
2403 void eval(
const std::vector<array>& inputs,
array& out);
2411 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2413 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2419 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2427 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2429 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2451 return std::make_pair(tri_, upper_);
2455 void eval(
const std::vector<array>& inputs,
array& output);
2483 uplo_(
std::move(uplo)),
2484 compute_eigenvectors_(compute_eigenvectors) {}
2486 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2488 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2498 return std::make_pair(uplo_, compute_eigenvectors_);
2502 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2504 bool compute_eigenvectors_;
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:173
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:206
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:190
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:217
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:226
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:240
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:257
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:274
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:291
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:325
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:308
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:342
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:370
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:359
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:383
@ ArgMin
Definition primitives.h:384
@ ArgMax
Definition primitives.h:385
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:388
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:412
int state() const
Definition primitives.h:422
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:469
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:457
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:434
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:445
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:485
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:483
@ RightShift
Definition primitives.h:483
@ Or
Definition primitives.h:483
@ LeftShift
Definition primitives.h:483
@ And
Definition primitives.h:483
@ Xor
Definition primitives.h:483
auto state() const
Definition primitives.h:496
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:520
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:506
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:552
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:566
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:577
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:589
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:601
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2467
Cholesky(Stream stream, bool upper)
Definition primitives.h:2462
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::string lib_name() const
Definition primitives.h:645
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:671
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:660
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:683
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:699
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:718
auto state() const
Definition primitives.h:745
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:768
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:785
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:802
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:871
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:920
DivMod(Stream stream)
Definition primitives.h:909
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Divide(Stream stream)
Definition primitives.h:892
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:2087
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2100
auto state() const
Definition primitives.h:2122
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:2111
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2497
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2481
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:975
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:964
auto state() const
Definition primitives.h:982
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:993
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:1010
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:1027
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:1074
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:1060
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:1044
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:1085
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1100
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1114
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1127
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1139
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1156
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1185
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1172
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:532
auto state() const
Definition primitives.h:1711
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1698
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1214
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1197
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1231
auto state() const
Definition primitives.h:1243
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1255
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2442
auto state() const
Definition primitives.h:2450
void eval_cpu(const std::vector< array > &inputs, array &output) override
LessEqual(Stream stream)
Definition primitives.h:1286
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1269
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1303
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1373
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1440
Base
Definition primitives.h:1335
@ ten
Definition primitives.h:1335
@ two
Definition primitives.h:1335
@ e
Definition primitives.h:1335
Log(Stream stream, Base base)
Definition primitives.h:1337
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1352
Base state() const
Definition primitives.h:1348
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1406
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1389
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1423
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1457
Maximum(Stream stream)
Definition primitives.h:1476
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1493
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1510
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1527
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1544
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1577
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1561
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1580
auto state() const
Definition primitives.h:1611
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1594
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1625
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1636
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1649
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2409
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1666
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1684
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1734
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1725
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1747
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1785
ReduceType
Definition primitives.h:1783
@ Min
Definition primitives.h:1783
@ Or
Definition primitives.h:1783
@ Max
Definition primitives.h:1783
@ And
Definition primitives.h:1783
@ Sum
Definition primitives.h:1783
@ Prod
Definition primitives.h:1783
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1804
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1827
Remainder(Stream stream)
Definition primitives.h:947
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1761
std::vector< int > state() const
Definition primitives.h:1771
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1840
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2425
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1857
@ Prod
Definition primitives.h:1857
@ Min
Definition primitives.h:1857
@ Max
Definition primitives.h:1857
@ Sum
Definition primitives.h:1857
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1895
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1859
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1877
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1944
ReduceType
Definition primitives.h:1910
@ Sum
Definition primitives.h:1910
@ Max
Definition primitives.h:1910
@ Prod
Definition primitives.h:1910
@ None
Definition primitives.h:1910
@ Min
Definition primitives.h:1910
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1924
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1912
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:930
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1956
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1973
Sin(Stream stream)
Definition primitives.h:1990
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:2007
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2041
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:2024
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:2055
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2073
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2132
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2144
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2166
Sort(Stream stream, int axis)
Definition primitives.h:2155
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2190
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2178
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
auto state() const
Definition primitives.h:2230
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2220
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2234
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2203
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2282
auto state() const
Definition primitives.h:2296
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2249
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2265
Tan(Stream stream)
Definition primitives.h:2307
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2324
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2385
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2396
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2341
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2354
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2375
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2366
void eval_gpu(const std::vector< array > &inputs, array &out) override
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
std::vector< ShapeElem > Shape
Definition array.h:21
Stream new_stream(Device d)
Make a new stream on the given device.
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
Device device
Definition stream.h:11