From 4e22a1dffef2ded37fd0f537c54ae0e215e59d26 Mon Sep 17 00:00:00 2001 From: Saanidhya <50399005+Saanidhyavats@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:40:41 -0400 Subject: [PATCH] In continuation to PR1243 to solve issue #1240 (#1365) * Solves issue #1240 * Correction * Update python/mlx/utils.py * Update python/mlx/utils.py --------- Co-authored-by: Awni Hannun Co-authored-by: Awni Hannun --- python/mlx/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 39b9ed21a..6754232a6 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,6 +1,6 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple def tree_map( @@ -111,7 +111,9 @@ def tree_map_with_path( return fn(path, tree, *rest) -def tree_flatten(tree, prefix="", is_leaf=None): +def tree_flatten( + tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None +) -> Any: """Flattens a Python tree to a list of key, value tuples. The keys are using the dot notation to define trees of arbitrary depth and @@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None): return [(prefix[1:], tree)] -def tree_unflatten(tree): +def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: """Recreate a Python tree from its flat representation. .. code-block:: python