diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 39b9ed21a..6754232a6 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,6 +1,6 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple def tree_map( @@ -111,7 +111,9 @@ def tree_map_with_path( return fn(path, tree, *rest) -def tree_flatten(tree, prefix="", is_leaf=None): +def tree_flatten( + tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None +) -> Any: """Flattens a Python tree to a list of key, value tuples. The keys are using the dot notation to define trees of arbitrary depth and @@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None): return [(prefix[1:], tree)] -def tree_unflatten(tree): +def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: """Recreate a Python tree from its flat representation. .. code-block:: python