7std::pair<std::vector<array>, std::vector<array>>
vmap_trace(
8 const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
31 return tracing_counter > 0;
35 static int tracing_counter;
47 return tracing_counter > 0;
51 static int tracing_counter;
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
Definition transforms_impl.h:22
InTracing()
Definition transforms_impl.h:23
~InTracing()
Definition transforms_impl.h:26
static bool in_tracing()
Definition transforms_impl.h:30
Definition transforms_impl.h:38
static bool retain_graph()
Definition transforms_impl.h:46
~RetainGraph()
Definition transforms_impl.h:42
RetainGraph()
Definition transforms_impl.h:39