13void eval(std::vector<array> outputs);
 
   16void eval(Arrays&&... outputs) {
 
   17  eval(std::vector<array>{std::forward<Arrays>(outputs)...});
 
 
   27std::pair<std::vector<array>, std::vector<array>> 
vjp(
 
   28    const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
 
   29    const std::vector<array>& primals,
 
   30    const std::vector<array>& cotangents);
 
   35std::pair<array, array> 
vjp(
 
   38    const array& cotangent);
 
   47std::pair<std::vector<array>, std::vector<array>> 
jvp(
 
   48    const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
 
   49    const std::vector<array>& primals,
 
   50    const std::vector<array>& tangents);
 
   55std::pair<array, array> 
jvp(
 
   58    const array& tangent);
 
   64    std::function<std::pair<std::vector<array>, std::vector<array>>(
 
   65        const std::vector<array>&)>;
 
   67    const std::vector<array>&)>;
 
   74    const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
 
   75    const std::vector<int>& argnums);
 
   82    const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
 
 
   92    const std::function<
array(
const array&)>& fun) {
 
   93  return [fun](
auto inputs) { 
return vjp(fun, inputs, 
array(1.0f)); };
 
 
   97    const std::function<
array(
const std::vector<array>&)>& fun,
 
   98    const std::vector<int>& argnums) {
 
   99  return [fun, argnums](
auto inputs) {
 
  101        [fun](
auto inputs) { 
return std::vector<array>{fun(inputs)}; },
 
  104    return std::make_pair(result.first[0], result.second);
 
 
  109    const std::function<
array(
const std::vector<array>&)>& fun,
 
 
  122std::function<std::vector<array>(
const std::vector<array>&)> 
inline grad(
 
  123    const std::function<
array(
const std::vector<array>&)>& fun,
 
  124    const std::vector<int>& argnums) {
 
  126  return [fn](
const std::vector<array>& inputs) { 
return fn(inputs).second; };
 
 
  137std::function<std::vector<array>(
const std::vector<array>&)> 
inline grad(
 
  138    const std::function<
array(
const std::vector<array>&)>& fun,
 
  140  return grad(fun, std::vector<int>{argnum});
 
 
  147    const std::function<
array(
const array&)>& fun) {
 
  149  return [fn](
const array& input) { 
return fn(input).second; };
 
 
  156    const std::function<
array(
const array&)>& fun,
 
  178std::function<std::vector<array>(
const std::vector<array>&)> 
vmap(
 
  179    const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
 
  180    const std::vector<int>& in_axes = {},
 
  181    const std::vector<int>& out_axes = {});
 
  193    std::function<std::vector<array>(
const std::vector<array>&)> fun,
 
  194    std::optional<std::function<std::vector<array>(
 
  195        const std::vector<array>&,
 
  196        const std::vector<array>&,
 
  197        const std::vector<array>&)>> fun_vjp = std::nullopt,
 
  198    std::optional<std::function<std::vector<array>(
 
  199        const std::vector<array>&,
 
  200        const std::vector<array>&,
 
  201        const std::vector<int>&)>> fun_jvp = std::nullopt,
 
  202    std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
 
  203        const std::vector<array>&,
 
  204        const std::vector<int>&)>> fun_vmap = std::nullopt);
 
  210std::function<std::vector<array>(
const std::vector<array>&)> 
custom_vjp(
 
  211    std::function<std::vector<array>(
const std::vector<array>&)> fun,
 
  212    std::function<std::vector<array>(
 
  213        const std::vector<array>&,
 
  214        const std::vector<array>&,
 
  215        const std::vector<array>&)> fun_vjp);
 
  221std::function<std::vector<array>(
const std::vector<array>&)> 
checkpoint(
 
  222    std::function<std::vector<array>(
const std::vector<array>&)> fun);
 
void async_eval(std::vector< array > outputs)
 
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
 
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
 
std::function< std::vector< array >(const std::vector< array > &)> checkpoint(std::function< std::vector< array >(const std::vector< array > &)> fun)
Checkpoint the gradient of a function.
 
std::function< std::pair< array, std::vector< array > >( const std::vector< array > &)> SimpleValueAndGradFn
Definition transforms.h:66
 
std::function< std::pair< array, array >(const array &)> value_and_grad(const std::function< array(const array &)> &fun)
Returns a function which computes the value and gradient of the unary input function.
Definition transforms.h:91
 
std::function< std::vector< array >(const std::vector< array > &)> custom_vjp(std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
Return a function that behaves exactly like fun but if the vjp of the results is computed fun_vjp will be used instead.
 
std::function< std::vector< array >(const std::vector< array > &)> custom_function(std::function< std::vector< array >(const std::vector< array > &)> fun, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> > fun_vjp=std::nullopt, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> > fun_jvp=std::nullopt, std::optional< std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> > fun_vmap=std::nullopt)
Redefine the transformations of fun according to the provided functions.
 
void eval(std::vector< array > outputs)
 
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
 
std::function< std::vector< array >(const std::vector< array > &)> grad(const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
Returns a function which computes the gradient of the input function with respect to a vector of input arrays.
Definition transforms.h:122
 
std::function< std::pair< std::vector< array >, std::vector< array > >( const std::vector< array > &)> ValueAndGradFn
Definition transforms.h:63
 
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:566