mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
parent
d29770eeaa
commit
473b6b43b4
@ -1,5 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
|
||||
@ -128,12 +130,10 @@ def tree_unflatten(tree):
|
||||
is_list = False
|
||||
|
||||
# collect children
|
||||
children = {}
|
||||
children = defaultdict(list)
|
||||
for key, value in tree:
|
||||
current_idx, *next_idx = key.split(".", maxsplit=1)
|
||||
next_idx = "" if not next_idx else next_idx[0]
|
||||
if current_idx not in children:
|
||||
children[current_idx] = []
|
||||
children[current_idx].append((next_idx, value))
|
||||
|
||||
# recursively map them to the original container
|
||||
|
Loading…
Reference in New Issue
Block a user