5#include <unordered_set> 
   12#define DEFINE_VMAP()                                                 \ 
   13  virtual std::pair<std::vector<array>, std::vector<int>> vmap(       \ 
   14      const std::vector<array>& inputs, const std::vector<int>& axes) \ 
 
   17#define DEFINE_GRADS()                           \ 
   18  std::vector<array> jvp(                        \ 
   19      const std::vector<array>& primals,         \ 
   20      const std::vector<array>& tangents,        \ 
   21      const std::vector<int>& argnums) override; \ 
   23  std::vector<array> vjp(                        \ 
   24      const std::vector<array>& primals,         \ 
   25      const std::vector<array>& cotangents,      \ 
   26      const std::vector<int>& argnums,           \ 
   27      const std::vector<array>& outputs) override; 
 
   29#define DEFINE_PRINT(PRIMITIVE)           \ 
   30  void print(std::ostream& os) override { \ 
 
   34#define DEFINE_DEFAULT_IS_EQUIVALENT()                        \ 
   35  bool is_equivalent(const Primitive& other) const override { \ 
 
   39#define DEFINE_INPUT_OUTPUT_SHAPE()                                  \ 
   40  std::vector<Shape> output_shapes(const std::vector<array>& inputs) \ 
   42    return {inputs[0].shape()};                                      \ 
 
   70      const std::vector<array>& inputs,
 
   71      std::vector<array>& outputs) = 0;
 
   73      const std::vector<array>& inputs,
 
   74      std::vector<array>& outputs) = 0;
 
   79  virtual std::vector<array> 
jvp(
 
   80      const std::vector<array>& primals,
 
   81      const std::vector<array>& tangents,
 
   82      const std::vector<int>& argnums);
 
   87  virtual std::vector<array> 
vjp(
 
   88      const std::vector<array>& primals,
 
   89      const std::vector<array>& cotangents,
 
   90      const std::vector<int>& argnums,
 
   91      const std::vector<array>& outputs);
 
   99  virtual std::pair<std::vector<array>, std::vector<int>> 
vmap(
 
  100      const std::vector<array>& inputs,
 
  101      const std::vector<int>& axes);
 
  104  virtual void print(std::ostream& os) = 0;
 
 
  137      const std::vector<array>& inputs,
 
  138      std::vector<array>& outputs)
 override {
 
 
  142      const std::vector<array>& inputs,
 
  143      std::vector<array>& outputs)
 override {
 
 
 
  196    return {alpha_, beta_};
 
 
 
  216    return {start_, stop_, step_};
 
 
 
  337    return {kth_, axis_};
 
 
 
  364    return {reduce_type_, axis_};
 
 
  368  ReduceType reduce_type_;
 
 
  417        shape_(
std::move(shape)),
 
  418        strides_(
std::move(strides)),
 
 
  428    return std::make_tuple(shape_, strides_, offset_);
 
 
  436  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
  471      const std::vector<array>& primals,
 
  472      const std::vector<array>& cotangents,
 
  473      const std::vector<int>& argnums,
 
  474      const std::vector<array>& outputs) 
override;
 
 
  494      const std::vector<array>& primals,
 
  495      const std::vector<array>& cotangents,
 
  496      const std::vector<int>& argnums,
 
  497      const std::vector<array>& outputs) 
override;
 
 
  517      const 
std::vector<
int>& ignore_axes);
 
  524  void eval(
const std::vector<array>& inputs, 
array& out);
 
  525  std::vector<int> ignore_axes_;
 
 
  549  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
  579      std::vector<array> inputs,
 
  580      std::vector<array> outputs,
 
  581      std::vector<array> tape,
 
  582      std::unordered_set<uintptr_t> constant_ids);
 
  584  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  586  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  600  const std::vector<array> inputs_;
 
  601  const std::vector<array> outputs_;
 
  602  const std::vector<array> tape_;
 
  603  const std::unordered_set<uintptr_t> constant_ids_;
 
  605  std::string kernel_lib_;
 
 
  658  bool allow_col_major_;
 
 
  665      const std::vector<int>& kernel_strides,
 
  666      const std::vector<int>& padding,
 
  667      const std::vector<int>& kernel_dilation,
 
  668      const std::vector<int>& input_dilation,
 
  669      const int groups = 1,
 
  670      const bool flip = 
false)
 
  673        kernel_strides_(kernel_strides),
 
  674        kernel_dilation_(kernel_dilation),
 
  675        input_dilation_(input_dilation),
 
 
  683      const std::vector<array>& primals,
 
  684      const std::vector<array>& cotangents,
 
  685      const std::vector<int>& argnums,
 
  686      const std::vector<array>& outputs) 
override;
 
  691    return std::make_tuple(
 
 
  701  std::vector<int> padding_;
 
  702  std::vector<int> kernel_strides_;
 
  703  std::vector<int> kernel_dilation_;
 
  704  std::vector<int> input_dilation_;
 
 
  759      std::function<std::vector<array>(
 
  760          const std::vector<array>&,
 
  761          const std::vector<array>&,
 
  762          const std::vector<array>&)> 
vjp,
 
  763      std::function<std::vector<array>(
 
  764          const std::vector<array>&,
 
  765          const std::vector<array>&,
 
  766          const std::vector<int>&)> 
jvp,
 
  767      std::function<std::pair<std::vector<array>, std::vector<int>>(
 
  768          const std::vector<array>&,
 
  769          const std::vector<int>&)> 
vmap)
 
  771        num_outputs_(num_outputs),
 
 
  776  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  778  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  786  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
  790  std::function<std::vector<array>(
 
  791      const std::vector<array>&,
 
  792      const std::vector<array>&,
 
  793      const std::vector<array>&)>
 
  795  std::function<std::vector<array>(
 
  796      const std::vector<array>&,
 
  797      const std::vector<array>&,
 
  798      const std::vector<int>&)>
 
  800  std::function<std::pair<std::vector<array>, std::vector<int>>(
 
  801      const std::vector<array>&,
 
  802      const std::vector<int>&)>
 
 
  810  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  812  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  816      const std::vector<array>& primals,
 
  817      const std::vector<array>& cotan,
 
  818      const std::vector<int>& argnums,
 
  819      const std::vector<array>& outputs) 
override;
 
  824  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
 
  845  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  847  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  855    return std::vector{inputs[0].shape(), inputs[0].shape()};
 
 
 
  991  void eval(
const std::vector<array>& inputs, 
array& out);
 
  992  std::vector<int> axes_;
 
 
  999      const std::vector<size_t>& axes,
 
 
 1013    return std::make_tuple(axes_, inverse_, real_);
 
 
 1017  std::vector<size_t> axes_;
 
 
 1038    return std::make_pair(start_axis_, end_axis_);
 
 
 1044  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1078        axes_(
std::move(axes)),
 
 1079        slice_sizes_(
std::move(slice_sizes)) {}
 
 
 1090    return {axes_, slice_sizes_};
 
 
 1094  std::vector<int> axes_;
 
 
 1215      std::shared_ptr<io::Reader> reader,
 
 1217      bool swap_endianness = 
false)
 
 1219        reader_(
std::move(reader)),
 
 1221        swap_endianness_(swap_endianness) {
 
 
 1237  std::shared_ptr<io::Reader> reader_;
 
 1239  bool swap_endianness_;
 
 
 1436      std::vector<int> axes,
 
 1440        axes_(
std::move(axes)),
 
 1441        inverted_(inverted),
 
 
 1454    return {axes_, inverted_, dtype_};
 
 
 1458  std::vector<int> axes_;
 
 1462  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1469      const std::vector<int>& axes,
 
 1470      const Shape& low_pad_size,
 
 1471      const Shape& high_pad_size)
 
 1474        low_pad_size_(low_pad_size),
 
 1475        high_pad_size_(high_pad_size) {}
 
 
 1485    return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
 
 
 1489  std::vector<int> axes_;
 
 1490  Shape low_pad_size_;
 
 1491  Shape high_pad_size_;
 
 
 1508    return std::make_pair(kth_, axis_);
 
 
 
 1538        group_size_(group_size),
 
 
 1551    return std::make_tuple(group_size_, bits_, transpose_);
 
 
 
 1564        group_size_(group_size),
 
 
 1576    return std::make_tuple(group_size_, bits_, transpose_);
 
 
 
 1597    return {shape_, width_};
 
 
 
 1648      const std::vector<int>& axes)
 
 
 1658      const 
std::vector<
array>& cotangents,
 
 1659      const 
std::vector<
int>& argnums,
 
 1660      const 
std::vector<
array>& outputs) override;
 
 1665    switch (reduce_type_) {
 
 
 1687  std::pair<ReduceType, std::vector<int>> 
state()
 const {
 
 1688    return {reduce_type_, axes_};
 
 
 1692  ReduceType reduce_type_;
 
 1693  std::vector<int> axes_;
 
 
 1721        reduce_type_(reduce_type),
 
 1724        inclusive_(inclusive) {}
 
 
 1734    switch (reduce_type_) {
 
 
 1751    return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
 
 
 1755  ReduceType reduce_type_;
 
 
 1768      const std::vector<int>& axes)
 
 
 1779    switch (reduce_type_) {
 
 
 1797  std::pair<ReduceType, std::vector<int>> 
state()
 const {
 
 1798    return {reduce_type_, axes_};
 
 
 1802  ReduceType reduce_type_;
 
 1803  std::vector<int> axes_;
 
 
 1820    os << 
"ScatterAxis";
 
 1821    switch (reduce_type_) {
 
 
 1832  std::pair<ReduceType, int> 
state()
 const {
 
 1833    return {reduce_type_, axis_};
 
 
 1837  ReduceType reduce_type_;
 
 
 1901      const Shape& start_indices,
 
 1902      const Shape& end_indices,
 
 1903      const Shape& strides)
 
 1905        start_indices_(start_indices),
 
 1906        end_indices_(end_indices),
 
 1907        strides_(strides) {}
 
 
 1917    return std::make_tuple(start_indices_, end_indices_, strides_);
 
 
 1921  Shape start_indices_;
 
 
 1930      const Shape& start_indices,
 
 1931      const Shape& end_indices,
 
 1932      const Shape& strides)
 
 1934        start_indices_(start_indices),
 
 1935        end_indices_(end_indices),
 
 1936        strides_(strides) {}
 
 
 1947    return std::make_tuple(start_indices_, end_indices_, strides_);
 
 
 1951  Shape start_indices_;
 
 
 1960        axes_(
std::move(axes)),
 
 1961        slice_size_(
std::move(slice_size)) {}
 
 
 1972    return std::make_pair(axes_, slice_size_);
 
 
 1976  std::vector<int> axes_;
 
 
 1998  std::vector<int> axes_;
 
 
 2049  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2051  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2059    return {indices_, axis_};
 
 
 2063  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
 
 2162  void eval(
const std::vector<array>& inputs, 
array& out);
 
 2163  std::vector<int> axes_;
 
 
 2210    return std::make_pair(axis_, shape_);
 
 
 2216  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 2256  std::vector<int> axes_;
 
 2258  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 2266  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2268  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 
 2279  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2281  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 
 2300    return std::make_pair(tri_, upper_);
 
 
 
 2330        uplo_(
std::move(uplo)),
 
 2331        compute_eigenvectors_(compute_eigenvectors) {}
 
 
 2333  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2335  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2345    return std::make_pair(uplo_, compute_eigenvectors_);
 
 
 2350  bool compute_eigenvectors_;
 
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Abs(Stream stream)
Definition primitives.h:156
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Add(Stream stream)
Definition primitives.h:170
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::pair< float, float > state() const
Definition primitives.h:195
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:184
 
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:206
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::tuple< double, double, double > state() const
Definition primitives.h:215
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcCos(Stream stream)
Definition primitives.h:227
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcCosh(Stream stream)
Definition primitives.h:241
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcSin(Stream stream)
Definition primitives.h:255
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcSinh(Stream stream)
Definition primitives.h:269
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcTan2(Stream stream)
Definition primitives.h:297
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcTan(Stream stream)
Definition primitives.h:283
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcTanh(Stream stream)
Definition primitives.h:311
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::pair< int, int > state() const
Definition primitives.h:336
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:325
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
ReduceType
Definition primitives.h:347
 
@ ArgMin
Definition primitives.h:348
 
@ ArgMax
Definition primitives.h:349
 
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:352
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::pair< ReduceType, int > state() const
Definition primitives.h:363
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
ArgSort(Stream stream, int axis)
Definition primitives.h:374
 
int state() const
Definition primitives.h:384
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:427
 
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:415
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
AsType(Stream stream, Dtype dtype)
Definition primitives.h:394
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
Dtype state() const
Definition primitives.h:405
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:443
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void print(std::ostream &os) override
Print the primitive.
 
Op
Definition primitives.h:441
 
@ RightShift
Definition primitives.h:441
 
@ Or
Definition primitives.h:441
 
@ LeftShift
Definition primitives.h:441
 
@ And
Definition primitives.h:441
 
@ Xor
Definition primitives.h:441
 
auto state() const
Definition primitives.h:454
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
auto state() const
Definition primitives.h:478
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:464
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:505
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:519
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
 
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:530
 
static Shape output_shape(const std::vector< array > &inputs)
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< int > state() const
Definition primitives.h:542
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Ceil(Stream stream)
Definition primitives.h:554
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:2315
 
Cholesky(Stream stream, bool upper)
Definition primitives.h:2310
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void print(std::ostream &os) override
Print the primitive.
 
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
std::string lib_name() const
Definition primitives.h:595
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:621
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
Concatenate(Stream stream, int axis)
Definition primitives.h:610
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Conjugate(Stream stream)
Definition primitives.h:631
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:644
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:663
 
auto state() const
Definition primitives.h:690
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Copy(Stream stream)
Definition primitives.h:711
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Cos(Stream stream)
Definition primitives.h:728
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Cosh(Stream stream)
Definition primitives.h:742
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Depends(Stream stream)
Definition primitives.h:808
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:854
 
DivMod(Stream stream)
Definition primitives.h:843
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Divide(Stream stream)
Definition primitives.h:829
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:1958
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1971
 
auto state() const
Definition primitives.h:1993
 
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:1982
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
auto state() const
Definition primitives.h:2344
 
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2328
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:900
 
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:889
 
auto state() const
Definition primitives.h:907
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Erf(Stream stream)
Definition primitives.h:917
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ErfInv(Stream stream)
Definition primitives.h:931
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Exp(Stream stream)
Definition primitives.h:945
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
static Shape output_shape(const array &input, const std::vector< int > &axes)
 
auto state() const
Definition primitives.h:986
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:972
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Expm1(Stream stream)
Definition primitives.h:959
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:997
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1012
 
static Shape output_shape(const array &input, int start_axis, int end_axis)
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1024
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1037
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Floor(Stream stream)
Definition primitives.h:1049
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Full(Stream stream)
Definition primitives.h:1063
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
GatherAxis(Stream stream, int axis)
Definition primitives.h:1100
 
auto state() const
Definition primitives.h:1111
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1089
 
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1076
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
GatherMM(Stream stream)
Definition primitives.h:488
 
auto state() const
Definition primitives.h:1575
 
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1562
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
GreaterEqual(Stream stream)
Definition primitives.h:1135
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Greater(Stream stream)
Definition primitives.h:1121
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Hadamard(Stream stream, float scale)
Definition primitives.h:1149
 
auto state() const
Definition primitives.h:1161
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Imag(Stream stream)
Definition primitives.h:1171
 
void eval_gpu(const std::vector< array > &inputs, array &output) override
 
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2291
 
auto state() const
Definition primitives.h:2299
 
void eval_cpu(const std::vector< array > &inputs, array &output) override
 
LessEqual(Stream stream)
Definition primitives.h:1199
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Less(Stream stream)
Definition primitives.h:1185
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1213
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Log1p(Stream stream)
Definition primitives.h:1281
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogAddExp(Stream stream)
Definition primitives.h:1336
 
Base
Definition primitives.h:1244
 
@ ten
Definition primitives.h:1244
 
@ two
Definition primitives.h:1244
 
@ e
Definition primitives.h:1244
 
Log(Stream stream, Base base)
Definition primitives.h:1246
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1261
 
Base state() const
Definition primitives.h:1257
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogicalAnd(Stream stream)
Definition primitives.h:1308
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogicalNot(Stream stream)
Definition primitives.h:1294
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
LogicalOr(Stream stream)
Definition primitives.h:1322
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Matmul(Stream stream)
Definition primitives.h:1350
 
Maximum(Stream stream)
Definition primitives.h:1364
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Minimum(Stream stream)
Definition primitives.h:1378
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Multiply(Stream stream)
Definition primitives.h:1392
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Negative(Stream stream)
Definition primitives.h:1406
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
NotEqual(Stream stream)
Definition primitives.h:1420
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1450
 
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1434
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1453
 
auto state() const
Definition primitives.h:1484
 
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1467
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1496
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
auto state() const
Definition primitives.h:1507
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Power(Stream stream)
Definition primitives.h:1518
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:48
 
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
 
virtual ~Primitive()=default
 
Primitive(const Primitive &other)=delete
 
Primitive(Primitive &&other)=delete
 
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
 
Primitive & operator=(Primitive &&other)=delete
 
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
 
Primitive & operator=(const Primitive &other)=delete
 
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
 
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
 
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
 
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual void print(std::ostream &os)=0
Print the primitive.
 
Primitive(Stream stream)
Definition primitives.h:50
 
QRF(Stream stream)
Definition primitives.h:2264
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1532
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1550
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1596
 
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1587
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Real(Stream stream)
Definition primitives.h:1607
 
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1645
 
ReduceType
Definition primitives.h:1643
 
@ Min
Definition primitives.h:1643
 
@ Or
Definition primitives.h:1643
 
@ Max
Definition primitives.h:1643
 
@ And
Definition primitives.h:1643
 
@ Sum
Definition primitives.h:1643
 
@ Prod
Definition primitives.h:1643
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1664
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1687
 
Remainder(Stream stream)
Definition primitives.h:875
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
static Shape output_shape(const array &input, Shape shape)
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1621
 
std::vector< int > state() const
Definition primitives.h:1631
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Round(Stream stream)
Definition primitives.h:1698
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
SVD(Stream stream)
Definition primitives.h:2277
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ReduceType
Definition primitives.h:1712
 
@ Prod
Definition primitives.h:1712
 
@ Min
Definition primitives.h:1712
 
@ Max
Definition primitives.h:1712
 
@ Sum
Definition primitives.h:1712
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
auto state() const
Definition primitives.h:1750
 
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1714
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1732
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::pair< ReduceType, int > state() const
Definition primitives.h:1832
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1819
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:1810
 
ReduceType
Definition primitives.h:1808
 
@ Sum
Definition primitives.h:1808
 
@ None
Definition primitives.h:1808
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1797
 
ReduceType
Definition primitives.h:1763
 
@ Sum
Definition primitives.h:1763
 
@ Max
Definition primitives.h:1763
 
@ Prod
Definition primitives.h:1763
 
@ None
Definition primitives.h:1763
 
@ Min
Definition primitives.h:1763
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1777
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1765
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Select(Stream stream)
Definition primitives.h:861
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Sigmoid(Stream stream)
Definition primitives.h:1843
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Sign(Stream stream)
Definition primitives.h:1857
 
Sin(Stream stream)
Definition primitives.h:1871
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Sinh(Stream stream)
Definition primitives.h:1885
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1916
 
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1899
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1928
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:1946
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Softmax(Stream stream, bool precise)
Definition primitives.h:2003
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:2015
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:2036
 
Sort(Stream stream, int axis)
Definition primitives.h:2025
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2058
 
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2046
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
auto state() const
Definition primitives.h:2095
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2085
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2099
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Square(Stream stream)
Definition primitives.h:2071
 
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2143
 
auto state() const
Definition primitives.h:2157
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
static Shape output_shape(const array &input, const std::vector< int > &axes)
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
StopGradient(Stream stream)
Definition primitives.h:2113
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Subtract(Stream stream)
Definition primitives.h:2129
 
Tan(Stream stream)
Definition primitives.h:2168
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Tanh(Stream stream)
Definition primitives.h:2182
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2240
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< int > state() const
Definition primitives.h:2251
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Definition primitives.h:126
 
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
 
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
 
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
 
UnaryPrimitive(UnaryPrimitive &&other)=delete
 
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
 
UnaryPrimitive(const UnaryPrimitive &other)=delete
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:136
 
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
 
virtual ~UnaryPrimitive()=default
 
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
 
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2196
 
static Shape output_shape(const array &input, int axis, const Shape &shape)
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:2209
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
auto state() const
Definition primitives.h:2230
 
void print(std::ostream &os) override
Print the primitive.
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
View(Stream stream, Dtype dtype)
Definition primitives.h:2221
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
 
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
 
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
 
array real(const array &a, StreamOrDevice s={})
 
std::vector< ShapeElem > Shape
Definition array.h:21
 
Stream new_stream(Device d)
Make a new stream on the given device.
 
std::vector< int64_t > Strides
Definition array.h:22
 
void eval(std::vector< array > outputs)
 
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
 
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
 
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
 
#define DEFINE_GRADS()
Definition primitives.h:17
 
#define DEFINE_VMAP()
Definition primitives.h:12
 
static constexpr DeviceType gpu
Definition device.h:14
 
static constexpr DeviceType cpu
Definition device.h:13
 
Device device
Definition stream.h:11