mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
parent
d29770eeaa
commit
473b6b43b4
@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
def tree_map(fn, tree, *rest, is_leaf=None):
|
def tree_map(fn, tree, *rest, is_leaf=None):
|
||||||
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
|
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
|
||||||
@ -128,12 +130,10 @@ def tree_unflatten(tree):
|
|||||||
is_list = False
|
is_list = False
|
||||||
|
|
||||||
# collect children
|
# collect children
|
||||||
children = {}
|
children = defaultdict(list)
|
||||||
for key, value in tree:
|
for key, value in tree:
|
||||||
current_idx, *next_idx = key.split(".", maxsplit=1)
|
current_idx, *next_idx = key.split(".", maxsplit=1)
|
||||||
next_idx = "" if not next_idx else next_idx[0]
|
next_idx = "" if not next_idx else next_idx[0]
|
||||||
if current_idx not in children:
|
|
||||||
children[current_idx] = []
|
|
||||||
children[current_idx].append((next_idx, value))
|
children[current_idx].append((next_idx, value))
|
||||||
|
|
||||||
# recursively map them to the original container
|
# recursively map them to the original container
|
||||||
|
Loading…
Reference in New Issue
Block a user