7std::pair<std::vector<array>, std::vector<array>>
vmap_trace(
8 const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
21std::function<std::vector<array>(
const std::vector<array>&)>
compile(
22 const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
23 std::uintptr_t fun_id,
24 bool shapeless =
false,
25 std::vector<uint64_t> constants = {});
46 return tracing_counter > 0;
50 static int tracing_counter;
62 return tracing_counter > 0;
66 static int tracing_counter;
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
void compile_clear_cache()
std::function< std::vector< array >(const std::vector< array > &) compile)(const std::function< std::vector< array >(const std::vector< array > &)> &fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
void compile_erase(std::uintptr_t fun_id)
Definition transforms_impl.h:37
InTracing()
Definition transforms_impl.h:38
~InTracing()
Definition transforms_impl.h:41
static bool in_tracing()
Definition transforms_impl.h:45
Definition transforms_impl.h:53
static bool retain_graph()
Definition transforms_impl.h:61
~RetainGraph()
Definition transforms_impl.h:57
RetainGraph()
Definition transforms_impl.h:54