5#include <unordered_map>
13std::function<std::vector<array>(
const std::vector<array>&)>
compile(
14 std::function<std::vector<array>(
const std::vector<array>&)> fun,
15 std::uintptr_t fun_id,
16 bool shapeless =
false,
17 std::vector<uint64_t> constants = {});
29 const std::function<std::vector<array>(
const std::vector<array>&)>& fun,
30 const std::vector<array>& inputs,
34 std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
38 const std::vector<array>& inputs,
39 const std::vector<array>& outputs,
40 const std::vector<array>& original_inputs);
44 std::vector<array>& tape,
46 std::vector<array>& outputs,
50 const std::vector<array>& tape,
51 const std::vector<array>& trace_inputs,
52 const std::vector<array>& trace_outputs,
53 const std::vector<array>& inputs,
Definition binary_ops.h:7
void compile_validate_shapeless(const std::vector< array > &tape)
void compile_simplify(std::vector< array > &tape, ParentsMap &parents_map, std::vector< array > &outputs, int passes)
void compile_clear_cache()
std::pair< std::vector< array >, ParentsMap > compile_dfs(const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &original_inputs)
std::vector< array > compile_replace(const std::vector< array > &tape, const std::vector< array > &trace_inputs, const std::vector< array > &trace_outputs, const std::vector< array > &inputs, bool shapeless)
void compile_erase(std::uintptr_t fun_id)
std::unordered_map< std::uintptr_t, std::vector< std::pair< array, int > > > ParentsMap
Definition compile_impl.h:33
std::pair< std::vector< array >, std::vector< array > > compile_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, bool shapeless)
bool compile_available_for_device(const Device &device)
std::function< std::vector< array >(const std::vector< array > &)> compile(std::function< std::vector< array >(const std::vector< array > &)> fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})