Support destination arg in tree flatten/unflatten (#2450)

This commit is contained in:
Luca Vivona 2025-08-06 18:34:59 -04:00 committed by GitHub
parent db5c7efcf6
commit 728d4db582
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 81 additions and 48 deletions

View File

@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state) state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", dict(state)) mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items())) state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For

View File

@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = dict(tree_flatten(model.parameters())) params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

View File

@ -178,7 +178,7 @@ class Module(dict):
if strict: if strict:
new_weights = dict(weights) new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters())) curr_weights = tree_flatten(self.parameters(), destination={})
if extras := (new_weights.keys() - curr_weights.keys()): if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras) num_extra = len(extras)
extras = ",\n".join(sorted(extras)) extras = ",\n".join(sorted(extras))
@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez` - ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors` - ``.safetensors`` will use :func:`mx.save_safetensors`
""" """
params_dict = dict(tree_flatten(self.parameters())) params_dict = tree_flatten(self.parameters(), destination={})
if file.endswith(".npz"): if file.endswith(".npz"):
mx.savez(file, **params_dict) mx.savez(file, **params_dict)

View File

@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import defaultdict
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def tree_map( def tree_map(
@ -114,8 +114,11 @@ def tree_map_with_path(
def tree_flatten( def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None tree: Any,
) -> Any: prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples. """Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and The keys are using the dot notation to define trees of arbitrary depth and
@ -128,9 +131,12 @@ def tree_flatten(
print(tree_flatten([[[0]]])) print(tree_flatten([[[0]]]))
# [("0.0.0", 0)] # [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello")) print(tree_flatten([[[0]]], prefix=".hello"))
# [("hello.0.0.0", 0)] # [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note:: .. note::
Dictionaries should have keys that are valid Python identifiers. Dictionaries should have keys that are valid Python identifiers.
@ -140,26 +146,50 @@ def tree_flatten(
always discarded. always discarded.
is_leaf (callable): An optional callable that returns True if the is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise. passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns: Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree. Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
""" """
flat_tree = [] if destination is None:
destination = []
if is_leaf is None or not is_leaf(tree): # Create the function to update the destination. We are taking advantage of
if isinstance(tree, (list, tuple)): # the fact that list.extend and dict.update have the same API to simplify
for i, t in enumerate(tree): # the code a bit.
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf)) if isinstance(destination, list):
return flat_tree _add_to_destination = destination.extend
if isinstance(tree, dict): elif isinstance(destination, dict):
for k, t in tree.items(): _add_to_destination = destination.update
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf)) else:
return flat_tree raise ValueError("Destination should be either a list or a dictionary or None")
return [(prefix[1:], tree)] # Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation. """Recreate a Python tree from its flat representation.
.. code-block:: python .. code-block:: python
@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
print(d) print(d)
# {"hello": {"world": 42}} # {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args: Args:
tree (list[tuple[str, Any]]): The flat representation of a Python tree. tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`. For instance as returned by :meth:`tree_flatten`.
Returns: Returns:
A Python tree. A Python tree.
""" """
if len(tree) == 1 and tree[0][0] == "": items = tree.items() if isinstance(tree, dict) else tree
return tree[0][1]
try: # Special case when we have just one element in the tree ie not a tree
int(tree[0][0].split(".", maxsplit=1)[0]) if len(items) == 1:
is_list = True key, value = next(iter(items))
except ValueError: if key == "":
is_list = False return value
# collect children # collect children
children = defaultdict(list) children = defaultdict(list)
for key, value in tree: for key, value in items:
current_idx, *next_idx = key.split(".", maxsplit=1) current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0] next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value)) children[current_idx].append((next_idx, value))
# recursively map them to the original container # Assume they are a list and fail to dict if the keys are not all integers
if is_list: try:
keys = sorted((int(idx), idx) for idx in children.keys()) keys = sorted((int(idx), idx) for idx in children.keys())
l = [] l = []
for i, k in keys: for i, k in keys:
@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))]) l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k])) l.append(tree_unflatten(children[k]))
return l return l
else: except ValueError:
return {k: tree_unflatten(v) for k, v in children.items()} return {k: tree_unflatten(v) for k, v in children.items()}

View File

@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))} self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule() model = DictModule()
params = dict(tree_flatten(model.parameters())) params = tree_flatten(model.parameters(), destination={})
self.assertEqual(len(params), 2) self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))