mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
@@ -11,6 +11,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
|
@@ -331,7 +331,7 @@ PyScalarT validate_shape(
|
||||
t = pycomplex;
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
||||
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
||||
<< " received in array initialization.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
277
python/src/export.cpp
Normal file
277
python/src/export.cpp
Normal file
@@ -0,0 +1,277 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/map.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
||||
validate_and_extract_inputs(
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs,
|
||||
const std::string& prefix) {
|
||||
auto maybe_throw = [&prefix](bool valid) {
|
||||
if (!valid) {
|
||||
throw std::invalid_argument(
|
||||
prefix +
|
||||
" Inputs can either be a variable "
|
||||
"number of positional and keyword arrays or a single tuple "
|
||||
"and/or dictionary of arrays.");
|
||||
}
|
||||
};
|
||||
std::vector<mx::array> args_;
|
||||
std::map<std::string, mx::array> kwargs_;
|
||||
if (args.size() == 0) {
|
||||
// No args so kwargs must be keyword arrays
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
} else if (args.size() > 0 && nb::isinstance<mx::array>(args[0])) {
|
||||
// Args are positional arrays and kwargs are keyword arrays
|
||||
maybe_throw(nb::try_cast(args, args_));
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
} else if (args.size() == 1) {
|
||||
// - args[0] can be a tuple or list or arrays or a dict
|
||||
// with string keys and array values
|
||||
// - kwargs should be empty
|
||||
maybe_throw(kwargs.size() == 0);
|
||||
if (!nb::try_cast(args[0], args_)) {
|
||||
maybe_throw(nb::try_cast(args[0], kwargs_));
|
||||
}
|
||||
} else if (args.size() == 2) {
|
||||
// - args[0] can be a tuple or list of arrays
|
||||
// - args[1] can be a dict of string keys with array values.
|
||||
// - kwargs should be empty
|
||||
maybe_throw(kwargs.size() == 0);
|
||||
maybe_throw(nb::try_cast(args[0], args_));
|
||||
maybe_throw(nb::try_cast(args[1], kwargs_));
|
||||
} else {
|
||||
maybe_throw(false);
|
||||
}
|
||||
return {args_, kwargs_};
|
||||
}
|
||||
|
||||
auto wrap_export_function(const nb::callable& fun) {
|
||||
return [fun](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
}
|
||||
|
||||
void init_export(nb::module_& m) {
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const std::string& file,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
mx::export_function(
|
||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a,
|
||||
R"pbdoc(
|
||||
Export a function to a file.
|
||||
|
||||
Example input arrays must be provided to export a function. The example
|
||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||
and/or dictionary of string keys with array values.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
fun (Callable): A function which takes as input zero or more
|
||||
:class:`array` and returns one or more :class:`array`.
|
||||
*args (array): Example array inputs to the function.
|
||||
shapeless (bool, optional): Whether or not the function allows
|
||||
inputs with variable shapes. Default: ``False``.
|
||||
**kwargs (array): Additional example keyword array inputs to the
|
||||
function.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1)
|
||||
y = mx.array([1, 2, 3])
|
||||
mx.export_function("fun.mlxfn", fun, x, y=y)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"import_function",
|
||||
[](const std::string& file) {
|
||||
return nb::cpp_function(
|
||||
[fn = mx::import_function(file)](
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] = validate_and_extract_inputs(
|
||||
args, kwargs, "[import_function::call]");
|
||||
return nb::tuple(nb::cast(fn(args_, kwargs_)));
|
||||
});
|
||||
},
|
||||
"file"_a,
|
||||
nb::sig("def import_function(file: str) -> Callable"),
|
||||
R"pbdoc(
|
||||
Import a function from a file.
|
||||
|
||||
The imported function can be called either with ``*args`` and
|
||||
``**kwargs`` or with a tuple of arrays and/or dictionary of string
|
||||
keys with array values. Imported functions always return a tuple of
|
||||
arrays.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): The file path to import the function from.
|
||||
|
||||
Returns:
|
||||
Callable: The imported function.
|
||||
|
||||
Example:
|
||||
>>> fn = mx.import_function("function.mlxfn")
|
||||
>>> out = fn(a, b, x=x, y=y)[0]
|
||||
>>>
|
||||
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||
)pbdoc");
|
||||
|
||||
nb::class_<mx::FunctionExporter>(
|
||||
m,
|
||||
"FunctionExporter",
|
||||
R"pbdoc(
|
||||
A context managing class for exporting multiple traces of the same
|
||||
function to a file.
|
||||
|
||||
Make an instance of this class by calling fun:`mx.exporter`.
|
||||
)pbdoc")
|
||||
.def("close", &mx::FunctionExporter::close)
|
||||
.def(
|
||||
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&) { exporter.close(); },
|
||||
"exc_type"_a = nb::none(),
|
||||
"exc_value"_a = nb::none(),
|
||||
"traceback"_a = nb::none())
|
||||
.def(
|
||||
"__call__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
exporter(args_, kwargs_);
|
||||
});
|
||||
|
||||
m.def(
|
||||
"exporter",
|
||||
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
||||
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
R"pbdoc(
|
||||
Make a callable object to export multiple traces of a function to a file.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
shapeless (bool, optional): Whether or not the function allows
|
||||
inputs with variable shapes. Default: ``False``.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(*args):
|
||||
return sum(args)
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
exporter(mx.array(1))
|
||||
exporter(mx.array(1), mx.array(2))
|
||||
exporter(mx.array(1), mx.array(2), mx.array(3))
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
mx::export_to_dot(out, arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
mx::export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[export_to_dot] Accepts file-like objects or strings "
|
||||
"to be used as filenames.");
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a,
|
||||
R"pbdoc(
|
||||
Export a graph to DOT format for visualization.
|
||||
|
||||
A variable number of output arrays can be provided for exporting
|
||||
The graph exported will recursively include all enevaluated inputs of
|
||||
the provided outputs.
|
||||
|
||||
Args:
|
||||
file (str): The file path to export to.
|
||||
*args (array): The output arrays.
|
||||
|
||||
Example:
|
||||
>>> a = mx.array(1) + mx.array(2)
|
||||
>>> mx.export_to_dot("graph.dot", a)
|
||||
)pbdoc");
|
||||
}
|
@@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
|
||||
void init_constants(nb::module_&);
|
||||
void init_fast(nb::module_&);
|
||||
void init_distributed(nb::module_&);
|
||||
void init_export(nb::module_&);
|
||||
|
||||
NB_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
@@ -39,6 +40,7 @@ NB_MODULE(core, m) {
|
||||
init_constants(m);
|
||||
init_fast(m);
|
||||
init_distributed(m);
|
||||
init_export(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
@@ -2898,7 +2898,7 @@ void init_ops(nb::module_& m) {
|
||||
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||
|
||||
Args:
|
||||
arrays (array): Input arrays.
|
||||
*arrays (array): Input arrays.
|
||||
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
||||
array has a single non-zero element. If ``False``, a dense grid is returned.
|
||||
Defaults to ``False``.
|
||||
@@ -3840,8 +3840,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
args (arrays): Arrays to be saved.
|
||||
kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
|
@@ -8,14 +8,12 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
|
||||
Note, all custom transformations are optional. Undefined transformations
|
||||
fall back to the default behaviour.
|
||||
|
||||
Example usage:
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.core as mx
|
||||
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
@@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
|
||||
Returns:
|
||||
Callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const nb::callable& fun,
|
||||
|
244
python/tests/test_export_import.py
Normal file
244
python/tests/test_export_import.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestExportImport(mlx_tests.MLXTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_basic_export_import(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
# Function with no inputs
|
||||
def fun():
|
||||
return mx.zeros((3, 3))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun()
|
||||
(out,) = imported()
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
# Simple function with inputs
|
||||
def fun(x):
|
||||
return mx.abs(mx.sin(x))
|
||||
|
||||
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Inputs in a list or tuple
|
||||
def fun(x):
|
||||
x = mx.abs(mx.sin(x))
|
||||
return x
|
||||
|
||||
mx.export_function(path, fun, [inputs])
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported([inputs])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
mx.export_function(path, fun, (inputs,))
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported((inputs,))
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a list
|
||||
def fun(x):
|
||||
return [mx.abs(mx.sin(x))]
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a tuple
|
||||
def fun(x):
|
||||
return (mx.abs(mx.sin(x)),)
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Check throws on invalid inputs / outputs
|
||||
def fun(x):
|
||||
return mx.abs(x)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, "hi")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, mx.array(1.0), "hi")
|
||||
|
||||
def fun(x):
|
||||
return mx.abs(x[0][0])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, [[mx.array(1.0)]])
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
|
||||
imported = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), 1.0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), [mx.array(1.0)])
|
||||
|
||||
def test_export_random_sample(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
mx.random.seed(5)
|
||||
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(3,))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
(out,) = imported()
|
||||
|
||||
mx.random.seed(5)
|
||||
expected = fun()
|
||||
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_export_with_kwargs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, z=None):
|
||||
out = x
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
x = mx.array([1, 2, 3])
|
||||
y = mx.array([1, 1, 0])
|
||||
z = mx.array([2, 2, 2])
|
||||
|
||||
mx.export_function(path, fun, (x,), {"z": z})
|
||||
imported_fun = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, y=z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun((x,), {"y": z})
|
||||
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
mx.export_function(path, fun, x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
# Only specify kwargs
|
||||
mx.export_function(path, fun, x=x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
with self.assertRaises(ValueError):
|
||||
out = imported_fun(x, z=z)[0]
|
||||
|
||||
out = imported_fun(x=x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun({"x": x, "z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
def test_export_variable_inputs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, y, z=None):
|
||||
out = x + y
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
|
||||
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
# A function with a large constant
|
||||
constant = mx.zeros((16, 2048))
|
||||
mx.eval(constant)
|
||||
|
||||
def fun(*args):
|
||||
return constant + sum(args)
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
for i in range(5):
|
||||
exporter(*[mx.array(1)] * i)
|
||||
|
||||
# Check the exported file size < constant size + small amount
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_and_load(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
@@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
@@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||
|
||||
def test_non_contiguous(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||
|
||||
save_file = os.path.join(self.test_dir, "a.npy")
|
||||
|
Reference in New Issue
Block a user