mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
@@ -8,14 +8,12 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
|
||||
Note, all custom transformations are optional. Undefined transformations
|
||||
fall back to the default behaviour.
|
||||
|
||||
Example usage:
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.core as mx
|
||||
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
@@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
|
||||
Returns:
|
||||
Callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const nb::callable& fun,
|
||||
|
||||
Reference in New Issue
Block a user