mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-24 06:16:38 +08:00
Minor refactor for tree_map and tree_unflatten (#311)
* Minor refact for tree_map and tree_unflatten * Remove the if statement --------- Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
parent
473b6b43b4
commit
2aedf3e791
@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return fn(tree, *rest)
|
||||
elif isinstance(tree, list):
|
||||
return [
|
||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||
for i, child in enumerate(tree)
|
||||
]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple(
|
||||
elif isinstance(tree, (list, tuple)):
|
||||
TreeType = type(tree)
|
||||
return TreeType(
|
||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||
for i, child in enumerate(tree)
|
||||
)
|
||||
@ -141,8 +137,8 @@ def tree_unflatten(tree):
|
||||
keys = sorted((int(idx), idx) for idx in children.keys())
|
||||
l = []
|
||||
for i, k in keys:
|
||||
while i > len(l):
|
||||
l.append({})
|
||||
# if i <= len(l), no {} will be appended.
|
||||
l.extend([{} for _ in range(i - len(l))])
|
||||
l.append(tree_unflatten(children[k]))
|
||||
return l
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user