feat: support dict and list destinations in tree_flatten, add tree_unflatten

This commit is contained in:
Luca Vivona
2025-07-29 17:25:57 -04:00
parent 970dbe8e25
commit ab51eb80ad

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import OrderedDict, defaultdict
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def tree_map( def tree_map(
@@ -114,52 +114,73 @@ def tree_map_with_path(
def tree_flatten( def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None tree: Any,
) -> Any: destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
prefix: str = "",
is_leaf: Optional[Callable] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples. """Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and The keys are using the dot notation to define trees of arbitrary depth and
complexity. complexity.
.. code-block:: python .. code-block:: python
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
print(tree_flatten([[[0]]])) print(tree_flatten([[[0]]]))
# [("0.0.0", 0)] # [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello"))
print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)] # [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}})
{"a.b": 1}
.. note:: .. note::
Dictionaries should have keys that are valid Python identifiers. Dictionaries should have keys that are valid Python identifiers.
Args: Args:
tree (Any): The Python tree to be flattened. tree (Any): The Python tree to be flattened.
destination(Any, optional): Container to store results. If None, creates an OrderedDict.
Can be a list (for tuples) or dict (for key-value mapping).
prefix (str): A prefix to use for the keys. The first character is prefix (str): A prefix to use for the keys. The first character is
always discarded. always discarded if it starts with a dot.
is_leaf (callable): An optional callable that returns True if the is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise. passed object is considered a leaf or False otherwise.
Returns: Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree. List[Tuple[str, Any]]: The flat representation of the Python tree.
""" """
flat_tree = [] if destination is None:
destination = OrderedDict()
if is_leaf is None or not is_leaf(tree): def _add_to_destination(k: str, v: Any) -> None:
if isinstance(tree, (list, tuple)): key: str = k[1:] if k.startswith(".") else k
for i, t in enumerate(tree): if isinstance(destination, list):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf)) destination.append((key, v))
return flat_tree elif isinstance(destination, dict):
if isinstance(tree, dict): destination[key] = v
for k, t in tree.items(): else:
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf)) raise ValueError(
return flat_tree f"Unsupported destination type: {type(destination)}. "
"Must be list, tuple, or dict."
)
return [(prefix[1:], tree)] if is_leaf is not None and is_leaf(tree):
_add_to_destination(prefix, tree)
return destination
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, destination, f"{prefix}.{i}", is_leaf)
return destination
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, destination, f"{prefix}.{key}", is_leaf)
return destination
_add_to_destination(prefix, tree)
return destination
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation. """Recreate a Python tree from its flat representation.
.. code-block:: python .. code-block:: python
@@ -171,12 +192,15 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
# {"hello": {"world": 42}} # {"hello": {"world": 42}}
Args: Args:
tree (list[tuple[str, Any]]): The flat representation of a Python tree. tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`. For instance as returned by :meth:`tree_flatten`.
Returns: Returns:
A Python tree. A Python tree.
""" """
if isinstance(tree, dict):
tree = list(tree.items())
if len(tree) == 1 and tree[0][0] == "": if len(tree) == 1 and tree[0][0] == "":
return tree[0][1] return tree[0][1]