Minor refactor for tree_map and tree_unflatten (#311)

* Minor refact for tree_map and tree_unflatten

* Remove the if statement

---------

Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
Chunyang Wen 2023-12-29 12:55:10 +08:00 committed by GitHub
parent 473b6b43b4
commit 2aedf3e791
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None):
"""
if is_leaf is not None and is_leaf(tree):
return fn(tree, *rest)
elif isinstance(tree, list):
return [
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
for i, child in enumerate(tree)
]
elif isinstance(tree, tuple):
return tuple(
elif isinstance(tree, (list, tuple)):
TreeType = type(tree)
return TreeType(
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
for i, child in enumerate(tree)
)
@ -141,8 +137,8 @@ def tree_unflatten(tree):
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
while i > len(l):
l.append({})
# if i <= len(l), no {} will be appended.
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
else: