mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
chore: change function with a destination dictionary object
This commit is contained in:
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state)
|
state = tree_flatten(optimizer.state, destination={})
|
||||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
mx.save_safetensors("optimizer.safetensors", state)
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuration parameter is saved in the state. For
|
Note, not every optimizer configuration parameter is saved in the state. For
|
||||||
|
@@ -7,17 +7,17 @@ Exporting Functions
|
|||||||
|
|
||||||
MLX has an API to export and import functions to and from a file. This lets you
|
MLX has an API to export and import functions to and from a file. This lets you
|
||||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||||
front-end (e.g. C++).
|
front-end (e.g. C++).
|
||||||
|
|
||||||
This guide walks through the basics of the MLX export API with some examples.
|
This guide walks through the basics of the MLX export API with some examples.
|
||||||
To see the full list of functions check out the :ref:`API documentation
|
To see the full list of functions check out the :ref:`API documentation
|
||||||
<export>`.
|
<export>`.
|
||||||
|
|
||||||
Basics of Exporting
|
Basics of Exporting
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
Let's start with a simple example:
|
Let's start with a simple example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
|||||||
|
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
y = mx.array(1.0)
|
y = mx.array(1.0)
|
||||||
|
|
||||||
# Both arguments to fun are positional
|
# Both arguments to fun are positional
|
||||||
mx.export_function("add.mlxfn", fun, x, y)
|
mx.export_function("add.mlxfn", fun, x, y)
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
|||||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||||
they are evaluated. The computation graph that gets exported will include
|
they are evaluated. The computation graph that gets exported will include
|
||||||
the computation that produces enclosed inputs.
|
the computation that produces enclosed inputs.
|
||||||
|
|
||||||
If the above example was missing ``mx.eval(model.parameters())``, the
|
If the above example was missing ``mx.eval(model.parameters())``, the
|
||||||
exported function would include the random initialization of the
|
exported function would include the random initialization of the
|
||||||
:obj:`mlx.nn.Module` parameters.
|
:obj:`mlx.nn.Module` parameters.
|
||||||
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
|||||||
# Set the model's parameters to the input parameters
|
# Set the model's parameters to the input parameters
|
||||||
model.update(tree_unflatten(list(params.items())))
|
model.update(tree_unflatten(list(params.items())))
|
||||||
return model(x)
|
return model(x)
|
||||||
|
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||||
|
|
||||||
|
|
||||||
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
|
|||||||
|
|
||||||
# Ok
|
# Ok
|
||||||
out, = imported_abs(mx.array(-1.0))
|
out, = imported_abs(mx.array(-1.0))
|
||||||
|
|
||||||
# Also ok
|
# Also ok
|
||||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||||
|
|
||||||
With ``shapeless=False`` (which is the default), the second call to
|
With ``shapeless=False`` (which is the default), the second call to
|
||||||
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
def fun(x, y=None):
|
def fun(x, y=None):
|
||||||
constant = mx.array(3.0)
|
constant = mx.array(3.0)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
x += y
|
x += y
|
||||||
return x + constant
|
return x + constant
|
||||||
|
|
||||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||||
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
print(out)
|
print(out)
|
||||||
|
|
||||||
In the above example the function constant data, (i.e. ``constant``), is only
|
In the above example the function constant data, (i.e. ``constant``), is only
|
||||||
saved once.
|
saved once.
|
||||||
|
|
||||||
Transformations with Imported Functions
|
Transformations with Imported Functions
|
||||||
---------------------------------------
|
---------------------------------------
|
||||||
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
|||||||
# Prints: array(1, dtype=float32)
|
# Prints: array(1, dtype=float32)
|
||||||
print(dfdx(x))
|
print(dfdx(x))
|
||||||
|
|
||||||
# Compile the imported function
|
# Compile the imported function
|
||||||
compiled_fun = mx.compile(imported_fun)
|
compiled_fun = mx.compile(imported_fun)
|
||||||
# Prints: array(0, dtype=float32)
|
# Prints: array(0, dtype=float32)
|
||||||
print(compiled_fun(x)[0])
|
print(compiled_fun(x)[0])
|
||||||
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
|||||||
// Prints: array(2, dtype=float32)
|
// Prints: array(2, dtype=float32)
|
||||||
std::cout << outputs[0] << std::endl;
|
std::cout << outputs[0] << std::endl;
|
||||||
|
|
||||||
Imported functions can be transformed in C++ just like in Python. Use
|
Imported functions can be transformed in C++ just like in Python. Use
|
||||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||||
|
|
||||||
|
@@ -178,7 +178,7 @@ class Module(dict):
|
|||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
new_weights = dict(weights)
|
new_weights = dict(weights)
|
||||||
curr_weights = dict(tree_flatten(self.parameters()))
|
curr_weights = tree_flatten(self.parameters(), destination={})
|
||||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||||
num_extra = len(extras)
|
num_extra = len(extras)
|
||||||
extras = ",\n".join(sorted(extras))
|
extras = ",\n".join(sorted(extras))
|
||||||
@@ -212,7 +212,7 @@ class Module(dict):
|
|||||||
- ``.npz`` will use :func:`mx.savez`
|
- ``.npz`` will use :func:`mx.savez`
|
||||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||||
"""
|
"""
|
||||||
params_dict = dict(tree_flatten(self.parameters()))
|
params_dict = tree_flatten(self.parameters(), destination={})
|
||||||
|
|
||||||
if file.endswith(".npz"):
|
if file.endswith(".npz"):
|
||||||
mx.savez(file, **params_dict)
|
mx.savez(file, **params_dict)
|
||||||
|
@@ -30,15 +30,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(len(flat_children), 3)
|
self.assertEqual(len(flat_children), 3)
|
||||||
|
|
||||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||||
self.assertEqual(len(leaves), 4)
|
if isinstance(leaves, list):
|
||||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
self.assertEqual(len(leaves), 4)
|
||||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||||
self.assertEqual(leaves[3][0], "layers.2")
|
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
self.assertEqual(leaves[3][0], "layers.2")
|
||||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||||
|
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||||
|
|
||||||
m.eval()
|
m.eval()
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||||
|
|
||||||
model = DictModule()
|
model = DictModule()
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
self.assertEqual(len(params), 2)
|
self.assertEqual(len(params), 2)
|
||||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||||
|
Reference in New Issue
Block a user