mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
* Solves issue #1240 * Correction * Update python/mlx/utils.py * Update python/mlx/utils.py --------- Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
291cf40aca
commit
4e22a1dffe
@ -1,6 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
def tree_map(
|
||||
@ -111,7 +111,9 @@ def tree_map_with_path(
|
||||
return fn(path, tree, *rest)
|
||||
|
||||
|
||||
def tree_flatten(tree, prefix="", is_leaf=None):
|
||||
def tree_flatten(
|
||||
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
|
||||
) -> Any:
|
||||
"""Flattens a Python tree to a list of key, value tuples.
|
||||
|
||||
The keys are using the dot notation to define trees of arbitrary depth and
|
||||
@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
|
||||
return [(prefix[1:], tree)]
|
||||
|
||||
|
||||
def tree_unflatten(tree):
|
||||
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
||||
"""Recreate a Python tree from its flat representation.
|
||||
|
||||
.. code-block:: python
|
||||
|
Loading…
Reference in New Issue
Block a user