5#include <unordered_set> 
   12#define DEFINE_VMAP()                                                 \ 
   13  virtual std::pair<std::vector<array>, std::vector<int>> vmap(       \ 
   14      const std::vector<array>& inputs, const std::vector<int>& axes) \ 
 
   17#define DEFINE_GRADS()                           \ 
   18  std::vector<array> jvp(                        \ 
   19      const std::vector<array>& primals,         \ 
   20      const std::vector<array>& tangents,        \ 
   21      const std::vector<int>& argnums) override; \ 
   23  std::vector<array> vjp(                        \ 
   24      const std::vector<array>& primals,         \ 
   25      const std::vector<array>& cotangents,      \ 
   26      const std::vector<int>& argnums,           \ 
   27      const std::vector<array>& outputs) override; 
 
   29#define DEFINE_PRINT(PRIMITIVE)           \ 
   30  void print(std::ostream& os) override { \ 
 
   34#define DEFINE_DEFAULT_IS_EQUIVALENT()                        \ 
   35  bool is_equivalent(const Primitive& other) const override { \ 
 
   39#define DEFINE_INPUT_OUTPUT_SHAPE()                \ 
   40  std::vector<std::vector<int>> output_shapes(     \ 
   41      const std::vector<array>& inputs) override { \ 
   42    return {inputs[0].shape()};                    \ 
 
   70      const std::vector<array>& inputs,
 
   71      std::vector<array>& outputs) = 0;
 
   73      const std::vector<array>& inputs,
 
   74      std::vector<array>& outputs) = 0;
 
   79  virtual std::vector<array> 
jvp(
 
   80      const std::vector<array>& primals,
 
   81      const std::vector<array>& tangents,
 
   82      const std::vector<int>& argnums);
 
   87  virtual std::vector<array> 
vjp(
 
   88      const std::vector<array>& primals,
 
   89      const std::vector<array>& cotangents,
 
   90      const std::vector<int>& argnums,
 
   91      const std::vector<array>& outputs);
 
   99  virtual std::pair<std::vector<array>, std::vector<int>> 
vmap(
 
  100      const std::vector<array>& inputs,
 
  101      const std::vector<int>& axes);
 
  104  virtual void print(std::ostream& os) = 0;
 
  114      const std::vector<array>& inputs);
 
 
  138      const std::vector<array>& inputs,
 
  139      std::vector<array>& outputs)
 override {
 
 
  143      const std::vector<array>& inputs,
 
  144      std::vector<array>& outputs)
 override {
 
 
 
  169  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  186  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  198      const std::vector<array>& primals,
 
  199      const std::vector<array>& cotangents,
 
  200      const std::vector<int>& argnums,
 
  201      const std::vector<array>& outputs) 
override;
 
 
  229  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  246  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  263  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  280  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  297  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  314  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  331  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  348  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  368  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  388      const std::vector<
array>& inputs) override;
 
  394  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  413  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  433  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  440      std::vector<int> shape,
 
  441      std::vector<size_t> strides,
 
  444        shape_(
std::move(shape)),
 
  445        strides_(
std::move(strides)),
 
 
  456  std::vector<
int> shape_;
 
  457  std::vector<
size_t> strides_;
 
  460  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  476  void print(std::ostream& os) override;
 
 
  492      const std::vector<array>& primals,
 
  493      const std::vector<array>& cotangents,
 
  494      const std::vector<int>& argnums,
 
  495      const std::vector<array>& outputs) 
override;
 
  503  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  514      const std::vector<array>& primals,
 
  515      const std::vector<array>& cotangents,
 
  516      const std::vector<int>& argnums,
 
  517      const std::vector<array>& outputs) 
override;
 
  523  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  540  std::vector<
int> shape_;
 
  542  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  559  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  575      std::vector<array> inputs,
 
  576      std::vector<array> outputs,
 
  577      std::vector<array> tape,
 
  578      std::unordered_set<uintptr_t> constant_ids);
 
  580  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  582  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  588      const std::vector<
array>& inputs) override;
 
  589  void print(std::ostream& os) override;
 
  592  std::
string lib_name()
 const {
 
 
  597  const std::vector<array> inputs_;
 
  598  const std::vector<array> outputs_;
 
  599  const std::vector<array> tape_;
 
  600  const std::unordered_set<uintptr_t> constant_ids_;
 
  602  std::string kernel_lib_;
 
 
  616  bool is_equivalent(const 
Primitive& other) const override;
 
  621  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  637  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  644      const std::vector<int>& kernel_strides,
 
  645      const std::vector<int>& padding,
 
  646      const std::vector<int>& kernel_dilation,
 
  647      const std::vector<int>& input_dilation,
 
  648      const int groups = 1,
 
  649      const bool flip = 
false)
 
  652        kernel_strides_(kernel_strides),
 
  653        kernel_dilation_(kernel_dilation),
 
  654        input_dilation_(input_dilation),
 
 
  662      const std::vector<array>& primals,
 
  663      const std::vector<array>& cotangents,
 
  664      const std::vector<int>& argnums,
 
  665      const std::vector<array>& outputs) 
override;
 
  668  bool is_equivalent(const 
Primitive& other) const override;
 
  671  std::vector<
int> padding_;
 
  672  std::vector<
int> kernel_strides_;
 
  673  std::vector<
int> kernel_dilation_;
 
  674  std::vector<
int> input_dilation_;
 
  678  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  695  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  712  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  729  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  737      std::function<std::vector<array>(
 
  738          const std::vector<array>&,
 
  739          const std::vector<array>&,
 
  740          const std::vector<array>&)> vjp,
 
  741      std::function<std::vector<array>(
 
  742          const std::vector<array>&,
 
  743          const std::vector<array>&,
 
  744          const std::vector<int>&)> jvp,
 
  745      std::function<std::pair<std::vector<array>, std::vector<int>>(
 
  746          const std::vector<array>&,
 
  747          const std::vector<int>&)> vmap)
 
  749        num_outputs_(num_outputs),
 
 
  754  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  756  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  764  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
  768  std::function<std::vector<array>(
 
  769      const std::vector<array>&,
 
  770      const std::vector<array>&,
 
  771      const std::vector<array>&)>
 
  773  std::function<std::vector<array>(
 
  774      const std::vector<array>&,
 
  775      const std::vector<array>&,
 
  776      const std::vector<int>&)>
 
  778  std::function<std::pair<std::vector<array>, std::vector<int>>(
 
  779      const std::vector<array>&,
 
  780      const std::vector<int>&)>
 
 
  788  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  790  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  794      const std::vector<array>& primals,
 
  795      const std::vector<array>& cotan,
 
  796      const std::vector<int>& argnums,
 
  797      const std::vector<array>& outputs) 
override;
 
  802  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
 
  819  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  826  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  828  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  835  std::vector<std::vector<
int>> output_shapes(
 
  836      const std::vector<
array>& inputs)
 override {
 
  837    return std::vector{inputs[0].shape(), inputs[0].shape()};
 
 
  841  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
 
  858  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  875  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  891  void print(std::ostream& os)
 override {
 
 
  900  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
  918  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  935  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  952  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  968  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
  975      const std::vector<size_t>& axes,
 
  978      : 
UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
 
 
  987  bool is_equivalent(const 
Primitive& other) const override;
 
  990  std::vector<
size_t> axes_;
 
  994  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1011  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1027  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1034      const std::vector<int>& axes,
 
 1035      const std::vector<int>& slice_sizes)
 
 1036      : 
UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
 
 
 1047  void eval(const std::vector<
array>& inputs, 
array& out);
 
 1048  std::vector<
int> axes_;
 
 1049  std::vector<
int> slice_sizes_;
 
 
 1066  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1083  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1104  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1121  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1138  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1145      std::shared_ptr<io::Reader> reader,
 
 1147      bool swap_endianness = 
false)
 
 1151        swap_endianness_(swap_endianness) {}
 
 
 1159  void eval(
const std::vector<array>& inputs, 
array& out);
 
 1160  std::shared_ptr<io::Reader> reader_;
 
 1162  bool swap_endianness_;
 
 
 1180  void print(std::ostream& os)
 override {
 
 
 1196  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1212  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1229  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1246  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1263  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1280  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1291      const std::vector<array>& primals,
 
 1292      const std::vector<array>& cotangents,
 
 1293      const std::vector<int>& argnums,
 
 1294      const std::vector<array>& outputs) 
override;
 
 
 1315  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1332  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1349  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1366  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1383  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1390      std::vector<int> axes,
 
 1394        axes_(
std::move(axes)),
 
 1395        inverted_(inverted),
 
 
 1404  std::vector<std::vector<
int>> output_shapes(
 
 1405      const std::vector<
array>& inputs)
 override {
 
 
 1410  std::vector<int> axes_;
 
 1414  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1421      const std::vector<int>& axes,
 
 1422      const std::vector<int>& low_pad_size,
 
 1423      const std::vector<int>& high_pad_size)
 
 1426        low_pad_size_(low_pad_size),
 
 1427        high_pad_size_(high_pad_size) {}
 
 
 1438  std::vector<
int> axes_;
 
 1439  std::vector<
int> low_pad_size_;
 
 1440  std::vector<
int> high_pad_size_;
 
 1442  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1463  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1480  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1491        group_size_(group_size),
 
 
 1508  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1515        group_size_(group_size),
 
 
 1532  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1548  std::vector<
int> shape_;
 
 1551  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1568  std::vector<
int> shape_;
 
 1570  void eval(const std::vector<
array>& inputs, 
array& out);
 
 1572  std::pair<
bool, std::vector<
size_t>> prepare_reshape(
 
 1575  void shared_buffer_reshape(
 
 1577      const std::vector<
size_t>& out_strides,
 
 
 1588      const std::vector<int>& axes)
 
 1589      : 
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
 
 
 1597      const std::vector<
array>& primals,
 
 1598      const std::vector<
array>& cotangents,
 
 1599      const std::vector<
int>& argnums,
 
 1600      const std::vector<
array>& outputs) override;
 
 1602  std::vector<std::vector<
int>> output_shapes(
 
 1603      const std::vector<
array>& inputs) override;
 
 1605  void print(std::ostream& os)
 override {
 
 1606    switch (reduce_type_) {
 
 
 1631  std::vector<int> axes_;
 
 1633  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1650  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1664        reduce_type_(reduce_type),
 
 1667        inclusive_(inclusive) {}
 
 
 1675  void print(std::ostream& os)
 override {
 
 1677    switch (reduce_type_) {
 
 
 1700  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1710      const std::vector<int>& axes)
 
 1711      : 
UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
 
 
 1719    switch (reduce_type_) {
 
 
 1739  void eval(
const std::vector<array>& inputs, 
array& out);
 
 1741  std::vector<int> axes_;
 
 
 1758  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1775  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1792  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1809  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1816      const std::vector<int>& start_indices,
 
 1817      const std::vector<int>& end_indices,
 
 1818      const std::vector<int>& strides)
 
 1820        start_indices_(start_indices),
 
 1821        end_indices_(end_indices),
 
 1822        strides_(strides) {}
 
 
 1833  std::vector<
int> start_indices_;
 
 1834  std::vector<
int> end_indices_;
 
 1835  std::vector<
int> strides_;
 
 1837  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1844      const std::vector<int>& start_indices,
 
 1845      const std::vector<int>& end_indices,
 
 1846      const std::vector<int>& strides)
 
 1848        start_indices_(start_indices),
 
 1849        end_indices_(end_indices),
 
 1850        strides_(strides) {}
 
 
 1861  std::vector<
int> start_indices_;
 
 1862  std::vector<
int> end_indices_;
 
 1863  std::vector<
int> strides_;
 
 1865  void eval(const std::vector<
array>& inputs, 
array& out);
 
 1867  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const 
array& in);
 
 
 1886  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1907  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1912  explicit Split(
Stream stream, 
const std::vector<int>& indices, 
int axis)
 
 1913      : 
Primitive(stream), indices_(indices), axis_(axis) {}
 
 
 1915  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 1917  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 1926  void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
 
 1928  std::vector<
int> indices_;
 
 
 1946  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 1962  void print(std::ostream& os)
 override {
 
 
 1971  void eval(
const std::vector<array>& inputs, 
array& out);
 
 
 1988  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2005  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2022  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2039  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2054  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2066  void print(std::ostream& os) override;
 
 
 2087  std::vector<
int> axes_;
 
 2089  void eval(const std::vector<
array>& inputs, 
array& out);
 
 
 2097  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2099  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2105  void eval(
const std::vector<array>& inputs, std::vector<array>& outputs);
 
 
 2113  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2115  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 2122  void eval(const std::vector<
array>& inputs, std::vector<
array>& outputs);
 
 
 2137  void eval(const std::vector<
array>& inputs, 
array& output);
 
 
 2152  void eval(const std::vector<
array>& inputs, 
array& output);
 
 
Definition primitives.h:155
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Abs(Stream stream)
Definition primitives.h:157
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
 
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
 
Definition primitives.h:172
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Add(Stream stream)
Definition primitives.h:174
 
Definition primitives.h:189
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
Definition primitives.h:213
 
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:232
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcCos(Stream stream)
Definition primitives.h:234
 
Definition primitives.h:249
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcCosh(Stream stream)
Definition primitives.h:251
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:266
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcSin(Stream stream)
Definition primitives.h:268
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:283
 
ArcSinh(Stream stream)
Definition primitives.h:285
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:317
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcTan2(Stream stream)
Definition primitives.h:319
 
Definition primitives.h:300
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArcTan(Stream stream)
Definition primitives.h:302
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:334
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArcTanh(Stream stream)
Definition primitives.h:336
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:351
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
 
Definition primitives.h:371
 
ReduceType
Definition primitives.h:373
 
@ ArgMin
Definition primitives.h:374
 
@ ArgMax
Definition primitives.h:375
 
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:378
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:397
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ArgSort(Stream stream, int axis)
Definition primitives.h:399
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:436
 
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:438
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:416
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
AsType(Stream stream, Dtype dtype)
Definition primitives.h:418
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:463
 
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:467
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Op
Definition primitives.h:465
 
@ And
Definition primitives.h:465
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:483
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:485
 
Definition primitives.h:526
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:528
 
Definition primitives.h:545
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Ceil(Stream stream)
Definition primitives.h:547
 
Definition primitives.h:2140
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Cholesky(Stream stream, bool upper)
Definition primitives.h:2142
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:562
 
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Definition primitives.h:605
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Concatenate(Stream stream, int axis)
Definition primitives.h:607
 
Definition primitives.h:624
 
Conjugate(Stream stream)
Definition primitives.h:626
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:640
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:642
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
Definition primitives.h:681
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Copy(Stream stream)
Definition primitives.h:683
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:698
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Cos(Stream stream)
Definition primitives.h:700
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:715
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Cosh(Stream stream)
Definition primitives.h:717
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:784
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Depends(Stream stream)
Definition primitives.h:786
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
Definition primitives.h:822
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
DivMod(Stream stream)
Definition primitives.h:824
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Definition primitives.h:805
 
Divide(Stream stream)
Definition primitives.h:807
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:878
 
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:880
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:904
 
Erf(Stream stream)
Definition primitives.h:906
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:921
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
ErfInv(Stream stream)
Definition primitives.h:923
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:938
 
Exp(Stream stream)
Definition primitives.h:940
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:955
 
Expm1(Stream stream)
Definition primitives.h:957
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:971
 
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:973
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:997
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Floor(Stream stream)
Definition primitives.h:999
 
Definition primitives.h:1014
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Full(Stream stream)
Definition primitives.h:1016
 
Definition primitives.h:1030
 
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1032
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:506
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
GatherMM(Stream stream)
Definition primitives.h:508
 
Definition primitives.h:1511
 
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1513
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1069
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
GreaterEqual(Stream stream)
Definition primitives.h:1071
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1052
 
Greater(Stream stream)
Definition primitives.h:1054
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1086
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Hadamard(Stream stream, float scale)
Definition primitives.h:1088
 
Definition primitives.h:2126
 
void eval_gpu(const std::vector< array > &inputs, array &output) override
 
Inverse(Stream stream)
Definition primitives.h:2128
 
void eval_cpu(const std::vector< array > &inputs, array &output) override
 
Definition primitives.h:1124
 
LessEqual(Stream stream)
Definition primitives.h:1126
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1107
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Less(Stream stream)
Definition primitives.h:1109
 
Definition primitives.h:1141
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1143
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1199
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Log1p(Stream stream)
Definition primitives.h:1201
 
Definition primitives.h:1266
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogAddExp(Stream stream)
Definition primitives.h:1268
 
Definition primitives.h:1165
 
Base
Definition primitives.h:1167
 
Log(Stream stream, Base base)
Definition primitives.h:1169
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1232
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogicalAnd(Stream stream)
Definition primitives.h:1234
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1215
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
LogicalNot(Stream stream)
Definition primitives.h:1217
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1249
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
LogicalOr(Stream stream)
Definition primitives.h:1251
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1283
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Matmul(Stream stream)
Definition primitives.h:1285
 
Definition primitives.h:1301
 
Maximum(Stream stream)
Definition primitives.h:1303
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1318
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Minimum(Stream stream)
Definition primitives.h:1320
 
Definition primitives.h:1335
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Multiply(Stream stream)
Definition primitives.h:1337
 
Definition primitives.h:1352
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Negative(Stream stream)
Definition primitives.h:1354
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1369
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
NotEqual(Stream stream)
Definition primitives.h:1371
 
Definition primitives.h:1386
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1388
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1417
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1419
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1445
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1447
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1466
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Power(Stream stream)
Definition primitives.h:1468
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:48
 
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
 
virtual ~Primitive()=default
 
Primitive(const Primitive &other)=delete
 
Primitive(Primitive &&other)=delete
 
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
 
Primitive & operator=(Primitive &&other)=delete
 
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
 
Primitive & operator=(const Primitive &other)=delete
 
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
 
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
 
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
 
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual void print(std::ostream &os)=0
Print the primitive.
 
Primitive(Stream stream)
Definition primitives.h:50
 
Definition primitives.h:2093
 
QRF(Stream stream)
Definition primitives.h:2095
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
Definition primitives.h:1483
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1485
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1535
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1537
 
Definition primitives.h:1581
 
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1585
 
ReduceType
Definition primitives.h:1583
 
@ And
Definition primitives.h:1583
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:861
 
Remainder(Stream stream)
Definition primitives.h:863
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1554
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1556
 
Definition primitives.h:1636
 
Round(Stream stream)
Definition primitives.h:1638
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:2109
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
SVD(Stream stream)
Definition primitives.h:2111
 
Definition primitives.h:1653
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
ReduceType
Definition primitives.h:1655
 
@ Max
Definition primitives.h:1655
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1657
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1703
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
ReduceType
Definition primitives.h:1705
 
@ Max
Definition primitives.h:1705
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1717
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1707
 
Definition primitives.h:844
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Select(Stream stream)
Definition primitives.h:846
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1744
 
Sigmoid(Stream stream)
Definition primitives.h:1746
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1761
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Sign(Stream stream)
Definition primitives.h:1763
 
Definition primitives.h:1778
 
Sin(Stream stream)
Definition primitives.h:1780
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1795
 
Sinh(Stream stream)
Definition primitives.h:1797
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1812
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1814
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1840
 
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1842
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1870
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Softmax(Stream stream, bool precise)
Definition primitives.h:1872
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1890
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Sort(Stream stream, int axis)
Definition primitives.h:1892
 
Definition primitives.h:1910
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1912
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
 
Definition primitives.h:1949
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1951
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:1932
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Square(Stream stream)
Definition primitives.h:1934
 
Definition primitives.h:1975
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
StopGradient(Stream stream)
Definition primitives.h:1977
 
Definition primitives.h:1991
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Subtract(Stream stream)
Definition primitives.h:1993
 
Definition primitives.h:2008
 
Tan(Stream stream)
Definition primitives.h:2010
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:2025
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Tanh(Stream stream)
Definition primitives.h:2027
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:2073
 
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2075
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Definition primitives.h:127
 
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
 
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
 
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
 
UnaryPrimitive(UnaryPrimitive &&other)=delete
 
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
 
UnaryPrimitive(const UnaryPrimitive &other)=delete
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
 
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
 
virtual ~UnaryPrimitive()=default
 
Definition primitives.h:2057
 
void eval_cpu(const std::vector< array > &inputs, array &out) override
 
View(Stream stream, Dtype dtype)
Definition primitives.h:2059
 
void eval_gpu(const std::vector< array > &inputs, array &out) override
 
Op op
Definition binary.h:141
 
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
 
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
 
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
 
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
 
void eval(std::vector< array > outputs)
 
std::function< array(const array &) vmap)(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
 
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
 
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
 
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
 
#define DEFINE_GRADS()
Definition primitives.h:17
 
#define DEFINE_VMAP()
Definition primitives.h:12
 
Definition binary_ops.h:270
 
Definition binary_ops.h:277
 
Device device
Definition stream.h:11