chore: change function with a destination dictionary object

This commit is contained in:
Luca Vivona
2025-07-31 22:19:55 -04:00
parent a16501fe03
commit 5659b12730
4 changed files with 29 additions and 28 deletions

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state) state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", dict(state)) mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items())) state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For Note, not every optimizer configuration parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++). front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples. This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check out the :ref:`API documentation To see the full list of functions check out the :ref:`API documentation
<export>`. <export>`.
Basics of Exporting Basics of Exporting
------------------- -------------------
Let's start with a simple example: Let's start with a simple example:
.. code-block:: python .. code-block:: python
def fun(x, y): def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0) x = mx.array(1.0)
y = mx.array(1.0) y = mx.array(1.0)
# Both arguments to fun are positional # Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y) mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs. the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters. :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters # Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = dict(tree_flatten(model.parameters())) params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok # Ok
out, = imported_abs(mx.array(-1.0)) out, = imported_abs(mx.array(-1.0))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None): def fun(x, y=None):
constant = mx.array(3.0) constant = mx.array(3.0)
if y is not None: if y is not None:
x += y x += y
return x + constant return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter: with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out) print(out)
In the above example the function constant data, (i.e. ``constant``), is only In the above example the function constant data, (i.e. ``constant``), is only
saved once. saved once.
Transformations with Imported Functions Transformations with Imported Functions
--------------------------------------- ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32) # Prints: array(1, dtype=float32)
print(dfdx(x)) print(dfdx(x))
# Compile the imported function # Compile the imported function
mx.compile(imported_fun) mx.compile(imported_fun)
# Prints: array(0, dtype=float32) # Prints: array(0, dtype=float32)
print(compiled_fun(x)[0]) print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32) // Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl; std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string, ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++. mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict: if strict:
new_weights = dict(weights) new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters())) curr_weights = tree_flatten(self.parameters(), destination={})
if extras := (new_weights.keys() - curr_weights.keys()): if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras) num_extra = len(extras)
extras = ",\n".join(sorted(extras)) extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez` - ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors` - ``.safetensors`` will use :func:`mx.save_safetensors`
""" """
params_dict = dict(tree_flatten(self.parameters())) params_dict = tree_flatten(self.parameters(), destination={})
if file.endswith(".npz"): if file.endswith(".npz"):
mx.savez(file, **params_dict) mx.savez(file, **params_dict)

View File

@@ -30,15 +30,16 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertEqual(len(flat_children), 3) self.assertEqual(len(flat_children), 3)
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
self.assertEqual(len(leaves), 4) if isinstance(leaves, list):
self.assertEqual(leaves[0][0], "layers.0.layers.0") self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[1][0], "layers.1.layers.0") self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1") self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[3][0], "layers.2") self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[3][1] is m.layers[2]) self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval() m.eval()
@@ -80,7 +81,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))} self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule() model = DictModule()
params = dict(tree_flatten(model.parameters())) params = tree_flatten(model.parameters(), destination={})
self.assertEqual(len(params), 2) self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))