In continuation to PR1243 to solve issue #1240 (#1365)

* Solves issue #1240

* Correction

* Update python/mlx/utils.py

* Update python/mlx/utils.py

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Saanidhya 2024-08-28 14:40:41 -04:00 committed by GitHub
parent 291cf40aca
commit 4e22a1dffe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple
def tree_map( def tree_map(
@ -111,7 +111,9 @@ def tree_map_with_path(
return fn(path, tree, *rest) return fn(path, tree, *rest)
def tree_flatten(tree, prefix="", is_leaf=None): def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples. """Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and The keys are using the dot notation to define trees of arbitrary depth and
@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
return [(prefix[1:], tree)] return [(prefix[1:], tree)]
def tree_unflatten(tree): def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation. """Recreate a Python tree from its flat representation.
.. code-block:: python .. code-block:: python