feat: support dict and list destinations in tree_flatten, add tree_unflatten

This commit is contained in:
Luca Vivona
2025-07-29 17:25:57 -04:00
parent 970dbe8e25
commit ab51eb80ad

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from collections import OrderedDict, defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def tree_map(
@@ -114,52 +114,73 @@ def tree_map_with_path(
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
tree: Any,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
prefix: str = "",
is_leaf: Optional[Callable] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
complexity.
.. code-block:: python
from mlx.utils import tree_flatten
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello"))
print(tree_flatten([[[0]]], prefix=".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}})
{"a.b": 1}
.. note::
Dictionaries should have keys that are valid Python identifiers.
Args:
tree (Any): The Python tree to be flattened.
destination(Any, optional): Container to store results. If None, creates an OrderedDict.
Can be a list (for tuples) or dict (for key-value mapping).
prefix (str): A prefix to use for the keys. The first character is
always discarded.
always discarded if it starts with a dot.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree.
"""
flat_tree = []
if destination is None:
destination = OrderedDict()
def _add_to_destination(k: str, v: Any) -> None:
key: str = k[1:] if k.startswith(".") else k
if isinstance(destination, list):
destination.append((key, v))
elif isinstance(destination, dict):
destination[key] = v
else:
raise ValueError(
f"Unsupported destination type: {type(destination)}. "
"Must be list, tuple, or dict."
)
if is_leaf is not None and is_leaf(tree):
_add_to_destination(prefix, tree)
return destination
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
for i, item in enumerate(tree):
tree_flatten(item, destination, f"{prefix}.{i}", is_leaf)
return destination
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
for key, value in tree.items():
tree_flatten(value, destination, f"{prefix}.{key}", is_leaf)
return destination
return [(prefix[1:], tree)]
_add_to_destination(prefix, tree)
return destination
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -171,12 +192,15 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A Python tree.
"""
if isinstance(tree, dict):
tree = list(tree.items())
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]