5#include <unordered_set>
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
42 return {inputs[0].shape()}; \
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
79 virtual std::vector<array>
jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
87 virtual std::vector<array>
vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
99 virtual std::pair<std::vector<array>, std::vector<int>>
vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
104 virtual void print(std::ostream& os) = 0;
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs)
override {
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs)
override {
196 return {alpha_, beta_};
216 return {start_, stop_, step_};
337 return {kth_, axis_};
364 return {reduce_type_, axis_};
368 ReduceType reduce_type_;
417 shape_(
std::move(shape)),
418 strides_(
std::move(strides)),
428 return std::make_tuple(shape_, strides_, offset_);
436 void eval(
const std::vector<array>& inputs,
array& out);
484 const std::vector<array>& primals,
485 const std::vector<array>& cotangents,
486 const std::vector<int>& argnums,
487 const std::vector<array>& outputs)
override;
507 const std::vector<array>& primals,
508 const std::vector<array>& cotangents,
509 const std::vector<int>& argnums,
510 const std::vector<array>& outputs)
override;
530 const
std::vector<
int>& ignore_axes);
537 void eval(
const std::vector<array>& inputs,
array& out);
538 std::vector<int> ignore_axes_;
562 void eval(
const std::vector<array>& inputs,
array& out);
592 std::vector<array> inputs,
593 std::vector<array> outputs,
594 std::vector<array> tape,
595 std::unordered_set<uintptr_t> constant_ids);
597 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
599 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
613 const std::vector<array> inputs_;
614 const std::vector<array> outputs_;
615 const std::vector<array> tape_;
616 const std::unordered_set<uintptr_t> constant_ids_;
618 std::string kernel_lib_;
671 bool allow_col_major_;
678 const std::vector<int>& kernel_strides,
679 const std::vector<int>& padding,
680 const std::vector<int>& kernel_dilation,
681 const std::vector<int>& input_dilation,
682 const int groups = 1,
683 const bool flip =
false)
686 kernel_strides_(kernel_strides),
687 kernel_dilation_(kernel_dilation),
688 input_dilation_(input_dilation),
696 const std::vector<array>& primals,
697 const std::vector<array>& cotangents,
698 const std::vector<int>& argnums,
699 const std::vector<array>& outputs)
override;
704 return std::make_tuple(
714 std::vector<int> padding_;
715 std::vector<int> kernel_strides_;
716 std::vector<int> kernel_dilation_;
717 std::vector<int> input_dilation_;
772 std::function<std::vector<array>(
773 const std::vector<array>&,
774 const std::vector<array>&,
775 const std::vector<array>&)>
vjp,
776 std::function<std::vector<array>(
777 const std::vector<array>&,
778 const std::vector<array>&,
779 const std::vector<int>&)>
jvp,
780 std::function<std::pair<std::vector<array>, std::vector<int>>(
781 const std::vector<array>&,
782 const std::vector<int>&)>
vmap)
784 num_outputs_(num_outputs),
789 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
791 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
799 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
803 std::function<std::vector<array>(
804 const std::vector<array>&,
805 const std::vector<array>&,
806 const std::vector<array>&)>
808 std::function<std::vector<array>(
809 const std::vector<array>&,
810 const std::vector<array>&,
811 const std::vector<int>&)>
813 std::function<std::pair<std::vector<array>, std::vector<int>>(
814 const std::vector<array>&,
815 const std::vector<int>&)>
823 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
825 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
829 const std::vector<array>& primals,
830 const std::vector<array>& cotan,
831 const std::vector<int>& argnums,
832 const std::vector<array>& outputs)
override;
837 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
858 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
860 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
868 return std::vector{inputs[0].shape(), inputs[0].shape()};
1004 void eval(
const std::vector<array>& inputs,
array& out);
1005 std::vector<int> axes_;
1012 const std::vector<size_t>& axes,
1026 return std::make_tuple(axes_, inverse_, real_);
1030 std::vector<size_t> axes_;
1051 return std::make_pair(start_axis_, end_axis_);
1057 void eval(
const std::vector<array>& inputs,
array& out);
1091 axes_(
std::move(axes)),
1092 slice_sizes_(
std::move(slice_sizes)) {}
1103 return {axes_, slice_sizes_};
1107 std::vector<int> axes_;
1228 std::shared_ptr<io::Reader> reader,
1230 bool swap_endianness =
false)
1232 reader_(
std::move(reader)),
1234 swap_endianness_(swap_endianness) {
1250 std::shared_ptr<io::Reader> reader_;
1252 bool swap_endianness_;
1449 std::vector<int> axes,
1453 axes_(
std::move(axes)),
1454 inverted_(inverted),
1467 return {axes_, inverted_, dtype_};
1471 std::vector<int> axes_;
1475 void eval(
const std::vector<array>& inputs,
array& out);
1482 const std::vector<int>& axes,
1483 const Shape& low_pad_size,
1484 const Shape& high_pad_size)
1487 low_pad_size_(low_pad_size),
1488 high_pad_size_(high_pad_size) {}
1498 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1502 std::vector<int> axes_;
1503 Shape low_pad_size_;
1504 Shape high_pad_size_;
1521 return std::make_pair(kth_, axis_);
1551 group_size_(group_size),
1564 return std::make_tuple(group_size_, bits_, transpose_);
1577 group_size_(group_size),
1589 return std::make_tuple(group_size_, bits_, transpose_);
1610 return {shape_, width_};
1661 const std::vector<int>& axes)
1671 const
std::vector<
array>& cotangents,
1672 const
std::vector<
int>& argnums,
1673 const
std::vector<
array>& outputs) override;
1678 switch (reduce_type_) {
1700 std::pair<ReduceType, std::vector<int>>
state()
const {
1701 return {reduce_type_, axes_};
1705 ReduceType reduce_type_;
1706 std::vector<int> axes_;
1734 reduce_type_(reduce_type),
1737 inclusive_(inclusive) {}
1747 switch (reduce_type_) {
1764 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1768 ReduceType reduce_type_;
1781 const std::vector<int>& axes)
1792 switch (reduce_type_) {
1810 std::pair<ReduceType, std::vector<int>>
state()
const {
1811 return {reduce_type_, axes_};
1815 ReduceType reduce_type_;
1816 std::vector<int> axes_;
1833 os <<
"ScatterAxis";
1834 switch (reduce_type_) {
1845 std::pair<ReduceType, int>
state()
const {
1846 return {reduce_type_, axis_};
1850 ReduceType reduce_type_;
1914 const Shape& start_indices,
1915 const Shape& end_indices,
1916 const Shape& strides)
1918 start_indices_(start_indices),
1919 end_indices_(end_indices),
1920 strides_(strides) {}
1930 return std::make_tuple(start_indices_, end_indices_, strides_);
1934 Shape start_indices_;
1943 const Shape& start_indices,
1944 const Shape& end_indices,
1945 const Shape& strides)
1947 start_indices_(start_indices),
1948 end_indices_(end_indices),
1949 strides_(strides) {}
1960 return std::make_tuple(start_indices_, end_indices_, strides_);
1964 Shape start_indices_;
1973 axes_(
std::move(axes)),
1974 slice_size_(
std::move(slice_size)) {}
1985 return std::make_pair(axes_, slice_size_);
1989 std::vector<int> axes_;
2011 std::vector<int> axes_;
2062 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2064 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2072 return {indices_, axis_};
2076 void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
2175 void eval(
const std::vector<array>& inputs,
array& out);
2176 std::vector<int> axes_;
2223 return std::make_pair(axis_, shape_);
2229 void eval(
const std::vector<array>& inputs,
array& out);
2269 std::vector<int> axes_;
2271 void eval(
const std::vector<array>& inputs,
array& out);
2279 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2281 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2292 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2294 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2313 return std::make_pair(tri_, upper_);
2343 uplo_(
std::move(uplo)),
2344 compute_eigenvectors_(compute_eigenvectors) {}
2345 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2347 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2357 return std::make_pair(uplo_, compute_eigenvectors_);
2362 bool compute_eigenvectors_;
2369 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
2371 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:170
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:195
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:184
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:206
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:227
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:241
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:255
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:269
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:297
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:283
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:311
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:325
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:347
@ ArgMin
Definition primitives.h:348
@ ArgMax
Definition primitives.h:349
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:352
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:363
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:374
int state() const
Definition primitives.h:384
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:427
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:415
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:394
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:405
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:443
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:441
@ RightShift
Definition primitives.h:441
@ Or
Definition primitives.h:441
@ LeftShift
Definition primitives.h:441
@ And
Definition primitives.h:441
@ Xor
Definition primitives.h:441
auto state() const
Definition primitives.h:454
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BitwiseInvert(Stream stream)
Definition primitives.h:464
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:491
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:477
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:518
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:532
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:543
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:555
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:567
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2328
Cholesky(Stream stream, bool upper)
Definition primitives.h:2323
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::string lib_name() const
Definition primitives.h:608
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:634
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:623
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:644
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:657
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:676
auto state() const
Definition primitives.h:703
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:724
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:741
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:755
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:821
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:867
DivMod(Stream stream)
Definition primitives.h:856
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Divide(Stream stream)
Definition primitives.h:842
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:1971
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1984
auto state() const
Definition primitives.h:2006
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:1995
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2356
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2341
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:913
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:902
auto state() const
Definition primitives.h:920
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:930
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:944
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:958
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:999
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:985
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:1010
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1025
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1037
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1062
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1076
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
GatherAxis(Stream stream, int axis)
Definition primitives.h:1113
auto state() const
Definition primitives.h:1124
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1102
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1089
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:501
auto state() const
Definition primitives.h:1588
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1575
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1148
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1134
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1162
auto state() const
Definition primitives.h:1174
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1184
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2304
auto state() const
Definition primitives.h:2312
void eval_cpu(const std::vector< array > &inputs, array &output) override
LUF(Stream stream)
Definition primitives.h:2368
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
LessEqual(Stream stream)
Definition primitives.h:1212
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1198
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1226
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1294
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1349
Base
Definition primitives.h:1257
@ ten
Definition primitives.h:1257
@ two
Definition primitives.h:1257
@ e
Definition primitives.h:1257
Log(Stream stream, Base base)
Definition primitives.h:1259
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1274
Base state() const
Definition primitives.h:1270
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1321
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1307
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1335
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1363
Maximum(Stream stream)
Definition primitives.h:1377
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1391
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1405
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1419
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1433
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1463
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1447
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1466
auto state() const
Definition primitives.h:1497
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1480
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1509
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1520
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1531
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2277
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1545
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1563
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1609
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1600
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1620
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1658
ReduceType
Definition primitives.h:1656
@ Min
Definition primitives.h:1656
@ Or
Definition primitives.h:1656
@ Max
Definition primitives.h:1656
@ And
Definition primitives.h:1656
@ Sum
Definition primitives.h:1656
@ Prod
Definition primitives.h:1656
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1677
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1700
Remainder(Stream stream)
Definition primitives.h:888
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1634
std::vector< int > state() const
Definition primitives.h:1644
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1711
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2290
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1725
@ Prod
Definition primitives.h:1725
@ Min
Definition primitives.h:1725
@ Max
Definition primitives.h:1725
@ Sum
Definition primitives.h:1725
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1763
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1727
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1745
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:1845
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1832
void eval_gpu(const std::vector< array > &inputs, array &out) override
ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:1823
ReduceType
Definition primitives.h:1821
@ Sum
Definition primitives.h:1821
@ None
Definition primitives.h:1821
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1810
ReduceType
Definition primitives.h:1776
@ Sum
Definition primitives.h:1776
@ Max
Definition primitives.h:1776
@ Prod
Definition primitives.h:1776
@ None
Definition primitives.h:1776
@ Min
Definition primitives.h:1776
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1790
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1778
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:874
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1856
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1870
Sin(Stream stream)
Definition primitives.h:1884
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:1898
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1929
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1912
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1941
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1959
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2016
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2028
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2049
Sort(Stream stream, int axis)
Definition primitives.h:2038
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2071
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2059
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
auto state() const
Definition primitives.h:2108
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2098
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2112
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2084
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2156
auto state() const
Definition primitives.h:2170
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2126
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2142
Tan(Stream stream)
Definition primitives.h:2181
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2195
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2253
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2264
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2209
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2222
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2243
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2234
void eval_gpu(const std::vector< array > &inputs, array &out) override
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
std::vector< ShapeElem > Shape
Definition array.h:21
Stream new_stream(Device d)
Make a new stream on the given device.
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
Device device
Definition stream.h:11