diff --git a/python/mlx/utils.py b/python/mlx/utils.py index b7173deb7..a7f68effa 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. -from collections import defaultdict +from collections import OrderedDict, defaultdict from itertools import zip_longest -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union def tree_map( @@ -114,52 +114,73 @@ def tree_map_with_path( def tree_flatten( - tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None -) -> Any: + tree: Any, + destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None, + prefix: str = "", + is_leaf: Optional[Callable] = None, +) -> Union[List[Tuple[str, Any]], Dict[str, Any]]: """Flattens a Python tree to a list of key, value tuples. - The keys are using the dot notation to define trees of arbitrary depth and complexity. .. code-block:: python - from mlx.utils import tree_flatten - print(tree_flatten([[[0]]])) # [("0.0.0", 0)] - - print(tree_flatten([[[0]]], ".hello")) + print(tree_flatten([[[0]]], prefix=".hello")) # [("hello.0.0.0", 0)] + tree_flatten({"a": {"b": 1}}) + {"a.b": 1} .. note:: Dictionaries should have keys that are valid Python identifiers. Args: tree (Any): The Python tree to be flattened. + destination(Any, optional): Container to store results. If None, creates an OrderedDict. + Can be a list (for tuples) or dict (for key-value mapping). prefix (str): A prefix to use for the keys. The first character is - always discarded. + always discarded if it starts with a dot. is_leaf (callable): An optional callable that returns True if the passed object is considered a leaf or False otherwise. Returns: List[Tuple[str, Any]]: The flat representation of the Python tree. """ - flat_tree = [] + if destination is None: + destination = OrderedDict() - if is_leaf is None or not is_leaf(tree): - if isinstance(tree, (list, tuple)): - for i, t in enumerate(tree): - flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf)) - return flat_tree - if isinstance(tree, dict): - for k, t in tree.items(): - flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf)) - return flat_tree + def _add_to_destination(k: str, v: Any) -> None: + key: str = k[1:] if k.startswith(".") else k + if isinstance(destination, list): + destination.append((key, v)) + elif isinstance(destination, dict): + destination[key] = v + else: + raise ValueError( + f"Unsupported destination type: {type(destination)}. " + "Must be list, tuple, or dict." + ) - return [(prefix[1:], tree)] + if is_leaf is not None and is_leaf(tree): + _add_to_destination(prefix, tree) + return destination + + if isinstance(tree, (list, tuple)): + for i, item in enumerate(tree): + tree_flatten(item, destination, f"{prefix}.{i}", is_leaf) + return destination + + if isinstance(tree, dict): + for key, value in tree.items(): + tree_flatten(value, destination, f"{prefix}.{key}", is_leaf) + return destination + + _add_to_destination(prefix, tree) + return destination -def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: +def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: """Recreate a Python tree from its flat representation. .. code-block:: python @@ -171,12 +192,15 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: # {"hello": {"world": 42}} Args: - tree (list[tuple[str, Any]]): The flat representation of a Python tree. + tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree. For instance as returned by :meth:`tree_flatten`. Returns: A Python tree. """ + if isinstance(tree, dict): + tree = list(tree.items()) + if len(tree) == 1 and tree[0][0] == "": return tree[0][1]