mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 05:14:40 +08:00
chore: change function with a destination dictonary object
This commit is contained in:
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
state = tree_flatten(optimizer.state, destination={})
|
||||
mx.save_safetensors("optimizer.safetensors", state)
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
|
@@ -7,17 +7,17 @@ Exporting Functions
|
||||
|
||||
MLX has an API to export and import functions to and from a file. This lets you
|
||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||
front-end (e.g. C++).
|
||||
front-end (e.g. C++).
|
||||
|
||||
This guide walks through the basics of the MLX export API with some examples.
|
||||
To see the full list of functions check-out the :ref:`API documentation
|
||||
<export>`.
|
||||
|
||||
Basics of Exporting
|
||||
Basics of Exporting
|
||||
-------------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
||||
|
||||
# Both arguments to fun are positional
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||
they are evaluated. The computation graph that gets exported will include
|
||||
the computation that produces enclosed inputs.
|
||||
|
||||
|
||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||
exported function would include the random initialization of the
|
||||
:obj:`mlx.nn.Module` parameters.
|
||||
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
||||
# Set the model's parameters to the input parameters
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
|
||||
# Also ok
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
With ``shapeless=False`` (which is the default), the second call to
|
||||
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
def fun(x, y=None):
|
||||
constant = mx.array(3.0)
|
||||
if y is not None:
|
||||
x += y
|
||||
x += y
|
||||
return x + constant
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
print(out)
|
||||
|
||||
In the above example the function constant data, (i.e. ``constant``), is only
|
||||
saved once.
|
||||
saved once.
|
||||
|
||||
Transformations with Imported Functions
|
||||
---------------------------------------
|
||||
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
||||
# Prints: array(1, dtype=float32)
|
||||
print(dfdx(x))
|
||||
|
||||
# Compile the imported function
|
||||
# Compile the imported function
|
||||
mx.compile(imported_fun)
|
||||
# Prints: array(0, dtype=float32)
|
||||
print(compiled_fun(x)[0])
|
||||
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
||||
// Prints: array(2, dtype=float32)
|
||||
std::cout << outputs[0] << std::endl;
|
||||
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||
|
||||
|
@@ -178,7 +178,7 @@ class Module(dict):
|
||||
|
||||
if strict:
|
||||
new_weights = dict(weights)
|
||||
curr_weights = dict(tree_flatten(self.parameters()))
|
||||
curr_weights = tree_flatten(self.parameters(), destination={})
|
||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||
num_extra = len(extras)
|
||||
extras = ",\n".join(sorted(extras))
|
||||
@@ -212,7 +212,7 @@ class Module(dict):
|
||||
- ``.npz`` will use :func:`mx.savez`
|
||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||
"""
|
||||
params_dict = dict(tree_flatten(self.parameters()))
|
||||
params_dict = tree_flatten(self.parameters(), destination={})
|
||||
|
||||
if file.endswith(".npz"):
|
||||
mx.savez(file, **params_dict)
|
||||
|
@@ -30,15 +30,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(len(flat_children), 3)
|
||||
|
||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
if isinstance(leaves, list):
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
|
||||
m.eval()
|
||||
|
||||
@@ -80,7 +81,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||
|
||||
model = DictModule()
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||
|
Reference in New Issue
Block a user