mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 21:21:16 +08:00
* Solves issue #1240 * Correction * Update python/mlx/utils.py * Update python/mlx/utils.py --------- Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
291cf40aca
commit
4e22a1dffe
@ -1,6 +1,6 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Callable, Optional, Tuple
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
def tree_map(
|
def tree_map(
|
||||||
@ -111,7 +111,9 @@ def tree_map_with_path(
|
|||||||
return fn(path, tree, *rest)
|
return fn(path, tree, *rest)
|
||||||
|
|
||||||
|
|
||||||
def tree_flatten(tree, prefix="", is_leaf=None):
|
def tree_flatten(
|
||||||
|
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
|
||||||
|
) -> Any:
|
||||||
"""Flattens a Python tree to a list of key, value tuples.
|
"""Flattens a Python tree to a list of key, value tuples.
|
||||||
|
|
||||||
The keys are using the dot notation to define trees of arbitrary depth and
|
The keys are using the dot notation to define trees of arbitrary depth and
|
||||||
@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
|
|||||||
return [(prefix[1:], tree)]
|
return [(prefix[1:], tree)]
|
||||||
|
|
||||||
|
|
||||||
def tree_unflatten(tree):
|
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
||||||
"""Recreate a Python tree from its flat representation.
|
"""Recreate a Python tree from its flat representation.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
Loading…
Reference in New Issue
Block a user