Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
This commit is contained in:
Awni Hannun
2024-12-24 11:19:13 -08:00
committed by GitHub
parent 935c8c4bb1
commit 4ba0c24a8f
35 changed files with 2239 additions and 90 deletions

View File

@@ -156,9 +156,6 @@ CompileMode& compile_mode() {
return compile_mode_;
}
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper like below but only merges the two provided arrays. If the src has
// siblings then these won't be merged to the dst.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
@@ -732,10 +729,15 @@ std::vector<array> compile_replace(
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) {
// Arrays in the tape without primitives are either:
// - inputs, which are already in the map
// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs