Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
This commit is contained in:
Awni Hannun
2024-12-24 11:19:13 -08:00
committed by GitHub
parent 935c8c4bb1
commit 4ba0c24a8f
35 changed files with 2239 additions and 90 deletions

View File

@@ -8,14 +8,12 @@
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <sstream>
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
@@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
Note, all custom transformations are optional. Undefined transformations
fall back to the default behaviour.
Example usage:
Example:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.core as mx
@mx.custom_function
def f(x, y):
return mx.sin(x) * y
@mx.custom_function
def f(x, y):
return mx.sin(x) * y
@f.vjp
def f_vjp(primals, cotangent, output):
@f.vjp
def f_vjp(primals, cotangent, output):
x, y = primals
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
@f.jvp
def f_jvp(primals, tangents):
x, y = primals
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
dx, dy = tangents
return dx * mx.cos(x) * y + dy * mx.sin(x)
@f.jvp
def f_jvp(primals, tangents):
x, y = primals
dx, dy = tangents
return dx * mx.cos(x) * y + dy * mx.sin(x)
@f.vmap
def f_vmap(inputs, axes):
x, y = inputs
ax, ay = axes
if ay != ax and ax is not None:
y = y.swapaxes(ay, ax)
return mx.sin(x) * y, (ax or ay)
@f.vmap
def f_vmap(inputs, axes):
x, y = inputs
ax, ay = axes
if ay != ax and ax is not None:
y = y.swapaxes(ay, ax)
return mx.sin(x) * y, (ax or ay)
)pbdoc")
.def(
nb::init<nb::callable>(),
@@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
Returns:
Callable: The vectorized function.
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a,
"args"_a);
m.def(
"compile",
[](const nb::callable& fun,