mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
feat: support dict and list destinations in tree_flatten, add tree_unflatten
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
from collections import defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def tree_map(
|
def tree_map(
|
||||||
@@ -114,52 +114,73 @@ def tree_map_with_path(
|
|||||||
|
|
||||||
|
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
|
tree: Any,
|
||||||
) -> Any:
|
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
is_leaf: Optional[Callable] = None,
|
||||||
|
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
|
||||||
"""Flattens a Python tree to a list of key, value tuples.
|
"""Flattens a Python tree to a list of key, value tuples.
|
||||||
|
|
||||||
The keys are using the dot notation to define trees of arbitrary depth and
|
The keys are using the dot notation to define trees of arbitrary depth and
|
||||||
complexity.
|
complexity.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
print(tree_flatten([[[0]]]))
|
print(tree_flatten([[[0]]]))
|
||||||
# [("0.0.0", 0)]
|
# [("0.0.0", 0)]
|
||||||
|
print(tree_flatten([[[0]]], prefix=".hello"))
|
||||||
print(tree_flatten([[[0]]], ".hello"))
|
|
||||||
# [("hello.0.0.0", 0)]
|
# [("hello.0.0.0", 0)]
|
||||||
|
tree_flatten({"a": {"b": 1}})
|
||||||
|
{"a.b": 1}
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Dictionaries should have keys that are valid Python identifiers.
|
Dictionaries should have keys that are valid Python identifiers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (Any): The Python tree to be flattened.
|
tree (Any): The Python tree to be flattened.
|
||||||
|
destination(Any, optional): Container to store results. If None, creates an OrderedDict.
|
||||||
|
Can be a list (for tuples) or dict (for key-value mapping).
|
||||||
prefix (str): A prefix to use for the keys. The first character is
|
prefix (str): A prefix to use for the keys. The first character is
|
||||||
always discarded.
|
always discarded if it starts with a dot.
|
||||||
is_leaf (callable): An optional callable that returns True if the
|
is_leaf (callable): An optional callable that returns True if the
|
||||||
passed object is considered a leaf or False otherwise.
|
passed object is considered a leaf or False otherwise.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, Any]]: The flat representation of the Python tree.
|
List[Tuple[str, Any]]: The flat representation of the Python tree.
|
||||||
"""
|
"""
|
||||||
flat_tree = []
|
if destination is None:
|
||||||
|
destination = OrderedDict()
|
||||||
|
|
||||||
if is_leaf is None or not is_leaf(tree):
|
def _add_to_destination(k: str, v: Any) -> None:
|
||||||
if isinstance(tree, (list, tuple)):
|
key: str = k[1:] if k.startswith(".") else k
|
||||||
for i, t in enumerate(tree):
|
if isinstance(destination, list):
|
||||||
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
|
destination.append((key, v))
|
||||||
return flat_tree
|
elif isinstance(destination, dict):
|
||||||
if isinstance(tree, dict):
|
destination[key] = v
|
||||||
for k, t in tree.items():
|
else:
|
||||||
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
|
raise ValueError(
|
||||||
return flat_tree
|
f"Unsupported destination type: {type(destination)}. "
|
||||||
|
"Must be list, tuple, or dict."
|
||||||
|
)
|
||||||
|
|
||||||
return [(prefix[1:], tree)]
|
if is_leaf is not None and is_leaf(tree):
|
||||||
|
_add_to_destination(prefix, tree)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
if isinstance(tree, (list, tuple)):
|
||||||
|
for i, item in enumerate(tree):
|
||||||
|
tree_flatten(item, destination, f"{prefix}.{i}", is_leaf)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
for key, value in tree.items():
|
||||||
|
tree_flatten(value, destination, f"{prefix}.{key}", is_leaf)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
_add_to_destination(prefix, tree)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
|
||||||
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
||||||
"""Recreate a Python tree from its flat representation.
|
"""Recreate a Python tree from its flat representation.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@@ -171,12 +192,15 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
|||||||
# {"hello": {"world": 42}}
|
# {"hello": {"world": 42}}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
|
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
|
||||||
For instance as returned by :meth:`tree_flatten`.
|
For instance as returned by :meth:`tree_flatten`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python tree.
|
A Python tree.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
tree = list(tree.items())
|
||||||
|
|
||||||
if len(tree) == 1 and tree[0][0] == "":
|
if len(tree) == 1 and tree[0][0] == "":
|
||||||
return tree[0][1]
|
return tree[0][1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user