In continuation to PR1243 to solve issue #1240 (#1365)

* Solves issue #1240

* Correction

* Update python/mlx/utils.py

* Update python/mlx/utils.py

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Saanidhya 2024-08-28 14:40:41 -04:00 committed by GitHub
parent 291cf40aca
commit 4e22a1dffe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
def tree_map(
@ -111,7 +111,9 @@ def tree_map_with_path(
return fn(path, tree, *rest)
def tree_flatten(tree, prefix="", is_leaf=None):
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
return [(prefix[1:], tree)]
def tree_unflatten(tree):
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python