Simplify the utils a bit

This commit is contained in:
Angelos Katharopoulos
2025-08-04 23:01:36 -07:00
parent 5659b12730
commit 8ff54a9595
2 changed files with 53 additions and 45 deletions

View File

@@ -115,21 +115,26 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
complexity.
.. code-block:: python
from mlx.utils import tree_flatten
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination=dict())
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note::
@@ -137,46 +142,50 @@ def tree_flatten(
Args:
tree (Any): The Python tree to be flattened.
destination(Any, optional): Container to store results. If None, creates an OrderedDict.
Can be a list (for tuples) or dict (for key-value mapping).
prefix (str): A prefix to use for the keys. The first character is
always discarded if it starts with a dot.
always discarded.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree.
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
"""
if destination is None:
destination = []
def _add_to_destination(k: str, v: Any) -> None:
key: str = k[1:] if k.startswith(".") else k
if isinstance(destination, list):
destination.append((key, v))
elif isinstance(destination, dict):
destination[key] = v
else:
raise ValueError(
f"Unsupported destination type: {type(destination)}. "
"Must be list, tuple, or dict."
)
# Create the function to update the destination. We are taking advantage of
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination(prefix, tree)
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, destination, f"{prefix}.{i}", is_leaf)
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, destination, f"{prefix}.{key}", is_leaf)
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
_add_to_destination(prefix, tree)
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
@@ -191,6 +200,10 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
print(d)
# {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
@@ -198,27 +211,23 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
Returns:
A Python tree.
"""
if isinstance(tree, dict):
tree = list(tree.items())
items = tree.items() if isinstance(tree, dict) else tree
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# Special case when we have just one element in the tree ie not a tree
if len(items) == 1:
key, value = next(iter(items))
if key == "":
return value
# collect children
children = defaultdict(list)
for key, value in tree:
for key, value in items:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value))
# recursively map them to the original container
if is_list:
# Assume they are a list and fail to dict if the keys are not all integers
try:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
@@ -226,7 +235,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
else:
except ValueError:
return {k: tree_unflatten(v) for k, v in children.items()}

View File

@@ -30,16 +30,15 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertEqual(len(flat_children), 3)
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
if isinstance(leaves, list):
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval()