mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-13 04:36:46 +08:00
Support destination arg in tree flatten/unflatten (#2450)
This commit is contained in:
parent
db5c7efcf6
commit
728d4db582
@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state)
|
state = tree_flatten(optimizer.state, destination={})
|
||||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
mx.save_safetensors("optimizer.safetensors", state)
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuation parameter is saved in the state. For
|
Note, not every optimizer configuation parameter is saved in the state. For
|
||||||
|
@ -7,17 +7,17 @@ Exporting Functions
|
|||||||
|
|
||||||
MLX has an API to export and import functions to and from a file. This lets you
|
MLX has an API to export and import functions to and from a file. This lets you
|
||||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||||
front-end (e.g. C++).
|
front-end (e.g. C++).
|
||||||
|
|
||||||
This guide walks through the basics of the MLX export API with some examples.
|
This guide walks through the basics of the MLX export API with some examples.
|
||||||
To see the full list of functions check-out the :ref:`API documentation
|
To see the full list of functions check-out the :ref:`API documentation
|
||||||
<export>`.
|
<export>`.
|
||||||
|
|
||||||
Basics of Exporting
|
Basics of Exporting
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
Let's start with a simple example:
|
Let's start with a simple example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
|||||||
|
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
y = mx.array(1.0)
|
y = mx.array(1.0)
|
||||||
|
|
||||||
# Both arguments to fun are positional
|
# Both arguments to fun are positional
|
||||||
mx.export_function("add.mlxfn", fun, x, y)
|
mx.export_function("add.mlxfn", fun, x, y)
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
|||||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||||
they are evaluated. The computation graph that gets exported will include
|
they are evaluated. The computation graph that gets exported will include
|
||||||
the computation that produces enclosed inputs.
|
the computation that produces enclosed inputs.
|
||||||
|
|
||||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||||
exported function would include the random initialization of the
|
exported function would include the random initialization of the
|
||||||
:obj:`mlx.nn.Module` parameters.
|
:obj:`mlx.nn.Module` parameters.
|
||||||
@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
|||||||
# Set the model's parameters to the input parameters
|
# Set the model's parameters to the input parameters
|
||||||
model.update(tree_unflatten(list(params.items())))
|
model.update(tree_unflatten(list(params.items())))
|
||||||
return model(x)
|
return model(x)
|
||||||
|
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||||
|
|
||||||
|
|
||||||
@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
|
|||||||
|
|
||||||
# Ok
|
# Ok
|
||||||
out, = imported_abs(mx.array(-1.0))
|
out, = imported_abs(mx.array(-1.0))
|
||||||
|
|
||||||
# Also ok
|
# Also ok
|
||||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||||
|
|
||||||
With ``shapeless=False`` (which is the default), the second call to
|
With ``shapeless=False`` (which is the default), the second call to
|
||||||
@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
def fun(x, y=None):
|
def fun(x, y=None):
|
||||||
constant = mx.array(3.0)
|
constant = mx.array(3.0)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
x += y
|
x += y
|
||||||
return x + constant
|
return x + constant
|
||||||
|
|
||||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||||
@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
print(out)
|
print(out)
|
||||||
|
|
||||||
In the above example the function constant data, (i.e. ``constant``), is only
|
In the above example the function constant data, (i.e. ``constant``), is only
|
||||||
saved once.
|
saved once.
|
||||||
|
|
||||||
Transformations with Imported Functions
|
Transformations with Imported Functions
|
||||||
---------------------------------------
|
---------------------------------------
|
||||||
@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
|||||||
# Prints: array(1, dtype=float32)
|
# Prints: array(1, dtype=float32)
|
||||||
print(dfdx(x))
|
print(dfdx(x))
|
||||||
|
|
||||||
# Compile the imported function
|
# Compile the imported function
|
||||||
mx.compile(imported_fun)
|
mx.compile(imported_fun)
|
||||||
# Prints: array(0, dtype=float32)
|
# Prints: array(0, dtype=float32)
|
||||||
print(compiled_fun(x)[0])
|
print(compiled_fun(x)[0])
|
||||||
@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
|||||||
// Prints: array(2, dtype=float32)
|
// Prints: array(2, dtype=float32)
|
||||||
std::cout << outputs[0] << std::endl;
|
std::cout << outputs[0] << std::endl;
|
||||||
|
|
||||||
Imported functions can be transformed in C++ just like in Python. Use
|
Imported functions can be transformed in C++ just like in Python. Use
|
||||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||||
|
|
||||||
|
@ -178,7 +178,7 @@ class Module(dict):
|
|||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
new_weights = dict(weights)
|
new_weights = dict(weights)
|
||||||
curr_weights = dict(tree_flatten(self.parameters()))
|
curr_weights = tree_flatten(self.parameters(), destination={})
|
||||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||||
num_extra = len(extras)
|
num_extra = len(extras)
|
||||||
extras = ",\n".join(sorted(extras))
|
extras = ",\n".join(sorted(extras))
|
||||||
@ -212,7 +212,7 @@ class Module(dict):
|
|||||||
- ``.npz`` will use :func:`mx.savez`
|
- ``.npz`` will use :func:`mx.savez`
|
||||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||||
"""
|
"""
|
||||||
params_dict = dict(tree_flatten(self.parameters()))
|
params_dict = tree_flatten(self.parameters(), destination={})
|
||||||
|
|
||||||
if file.endswith(".npz"):
|
if file.endswith(".npz"):
|
||||||
mx.savez(file, **params_dict)
|
mx.savez(file, **params_dict)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def tree_map(
|
def tree_map(
|
||||||
@ -114,8 +114,11 @@ def tree_map_with_path(
|
|||||||
|
|
||||||
|
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
|
tree: Any,
|
||||||
) -> Any:
|
prefix: str = "",
|
||||||
|
is_leaf: Optional[Callable] = None,
|
||||||
|
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
|
||||||
|
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
|
||||||
"""Flattens a Python tree to a list of key, value tuples.
|
"""Flattens a Python tree to a list of key, value tuples.
|
||||||
|
|
||||||
The keys are using the dot notation to define trees of arbitrary depth and
|
The keys are using the dot notation to define trees of arbitrary depth and
|
||||||
@ -128,9 +131,12 @@ def tree_flatten(
|
|||||||
print(tree_flatten([[[0]]]))
|
print(tree_flatten([[[0]]]))
|
||||||
# [("0.0.0", 0)]
|
# [("0.0.0", 0)]
|
||||||
|
|
||||||
print(tree_flatten([[[0]]], ".hello"))
|
print(tree_flatten([[[0]]], prefix=".hello"))
|
||||||
# [("hello.0.0.0", 0)]
|
# [("hello.0.0.0", 0)]
|
||||||
|
|
||||||
|
tree_flatten({"a": {"b": 1}}, destination={})
|
||||||
|
{"a.b": 1}
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Dictionaries should have keys that are valid Python identifiers.
|
Dictionaries should have keys that are valid Python identifiers.
|
||||||
|
|
||||||
@ -140,26 +146,50 @@ def tree_flatten(
|
|||||||
always discarded.
|
always discarded.
|
||||||
is_leaf (callable): An optional callable that returns True if the
|
is_leaf (callable): An optional callable that returns True if the
|
||||||
passed object is considered a leaf or False otherwise.
|
passed object is considered a leaf or False otherwise.
|
||||||
|
destination (list or dict, optional): A list or dictionary to store the
|
||||||
|
flattened tree. If None an empty list will be used. Default: ``None``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, Any]]: The flat representation of the Python tree.
|
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
|
||||||
|
the Python tree.
|
||||||
"""
|
"""
|
||||||
flat_tree = []
|
if destination is None:
|
||||||
|
destination = []
|
||||||
|
|
||||||
if is_leaf is None or not is_leaf(tree):
|
# Create the function to update the destination. We are taking advantage of
|
||||||
if isinstance(tree, (list, tuple)):
|
# the fact that list.extend and dict.update have the same API to simplify
|
||||||
for i, t in enumerate(tree):
|
# the code a bit.
|
||||||
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
|
if isinstance(destination, list):
|
||||||
return flat_tree
|
_add_to_destination = destination.extend
|
||||||
if isinstance(tree, dict):
|
elif isinstance(destination, dict):
|
||||||
for k, t in tree.items():
|
_add_to_destination = destination.update
|
||||||
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
|
else:
|
||||||
return flat_tree
|
raise ValueError("Destination should be either a list or a dictionary or None")
|
||||||
|
|
||||||
return [(prefix[1:], tree)]
|
# Leaf identified by is_leaf so add it and return
|
||||||
|
if is_leaf is not None and is_leaf(tree):
|
||||||
|
_add_to_destination([(prefix[1:], tree)])
|
||||||
|
return destination
|
||||||
|
|
||||||
|
# List or tuple so recursively add each subtree
|
||||||
|
if isinstance(tree, (list, tuple)):
|
||||||
|
for i, item in enumerate(tree):
|
||||||
|
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
# Dictionary so recursively add each subtree
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
for key, value in tree.items():
|
||||||
|
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
# Leaf so add it and return
|
||||||
|
_add_to_destination([(prefix[1:], tree)])
|
||||||
|
|
||||||
|
return destination
|
||||||
|
|
||||||
|
|
||||||
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
||||||
"""Recreate a Python tree from its flat representation.
|
"""Recreate a Python tree from its flat representation.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
|||||||
print(d)
|
print(d)
|
||||||
# {"hello": {"world": 42}}
|
# {"hello": {"world": 42}}
|
||||||
|
|
||||||
|
d = tree_unflatten({"hello.world": 42})
|
||||||
|
print(d)
|
||||||
|
# {"hello": {"world": 42}}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
|
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
|
||||||
For instance as returned by :meth:`tree_flatten`.
|
For instance as returned by :meth:`tree_flatten`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python tree.
|
A Python tree.
|
||||||
"""
|
"""
|
||||||
if len(tree) == 1 and tree[0][0] == "":
|
items = tree.items() if isinstance(tree, dict) else tree
|
||||||
return tree[0][1]
|
|
||||||
|
|
||||||
try:
|
# Special case when we have just one element in the tree ie not a tree
|
||||||
int(tree[0][0].split(".", maxsplit=1)[0])
|
if len(items) == 1:
|
||||||
is_list = True
|
key, value = next(iter(items))
|
||||||
except ValueError:
|
if key == "":
|
||||||
is_list = False
|
return value
|
||||||
|
|
||||||
# collect children
|
# collect children
|
||||||
children = defaultdict(list)
|
children = defaultdict(list)
|
||||||
for key, value in tree:
|
for key, value in items:
|
||||||
current_idx, *next_idx = key.split(".", maxsplit=1)
|
current_idx, *next_idx = key.split(".", maxsplit=1)
|
||||||
next_idx = "" if not next_idx else next_idx[0]
|
next_idx = "" if not next_idx else next_idx[0]
|
||||||
children[current_idx].append((next_idx, value))
|
children[current_idx].append((next_idx, value))
|
||||||
|
|
||||||
# recursively map them to the original container
|
# Assume they are a list and fail to dict if the keys are not all integers
|
||||||
if is_list:
|
try:
|
||||||
keys = sorted((int(idx), idx) for idx in children.keys())
|
keys = sorted((int(idx), idx) for idx in children.keys())
|
||||||
l = []
|
l = []
|
||||||
for i, k in keys:
|
for i, k in keys:
|
||||||
@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
|||||||
l.extend([{} for _ in range(i - len(l))])
|
l.extend([{} for _ in range(i - len(l))])
|
||||||
l.append(tree_unflatten(children[k]))
|
l.append(tree_unflatten(children[k]))
|
||||||
return l
|
return l
|
||||||
else:
|
except ValueError:
|
||||||
return {k: tree_unflatten(v) for k, v in children.items()}
|
return {k: tree_unflatten(v) for k, v in children.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||||
|
|
||||||
model = DictModule()
|
model = DictModule()
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
self.assertEqual(len(params), 2)
|
self.assertEqual(len(params), 2)
|
||||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||||
|
Loading…
Reference in New Issue
Block a user