mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-24 22:36:39 +08:00
Minor refactor for tree_map and tree_unflatten (#311)
* Minor refact for tree_map and tree_unflatten * Remove the if statement --------- Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
parent
473b6b43b4
commit
2aedf3e791
@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None):
|
|||||||
"""
|
"""
|
||||||
if is_leaf is not None and is_leaf(tree):
|
if is_leaf is not None and is_leaf(tree):
|
||||||
return fn(tree, *rest)
|
return fn(tree, *rest)
|
||||||
elif isinstance(tree, list):
|
elif isinstance(tree, (list, tuple)):
|
||||||
return [
|
TreeType = type(tree)
|
||||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
return TreeType(
|
||||||
for i, child in enumerate(tree)
|
|
||||||
]
|
|
||||||
elif isinstance(tree, tuple):
|
|
||||||
return tuple(
|
|
||||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||||
for i, child in enumerate(tree)
|
for i, child in enumerate(tree)
|
||||||
)
|
)
|
||||||
@ -141,8 +137,8 @@ def tree_unflatten(tree):
|
|||||||
keys = sorted((int(idx), idx) for idx in children.keys())
|
keys = sorted((int(idx), idx) for idx in children.keys())
|
||||||
l = []
|
l = []
|
||||||
for i, k in keys:
|
for i, k in keys:
|
||||||
while i > len(l):
|
# if i <= len(l), no {} will be appended.
|
||||||
l.append({})
|
l.extend([{} for _ in range(i - len(l))])
|
||||||
l.append(tree_unflatten(children[k]))
|
l.append(tree_unflatten(children[k]))
|
||||||
return l
|
return l
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user