diff --git a/python/mlx/utils.py b/python/mlx/utils.py index daa387420..137a8aae4 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None): """ if is_leaf is not None and is_leaf(tree): return fn(tree, *rest) - elif isinstance(tree, list): - return [ - tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) - for i, child in enumerate(tree) - ] - elif isinstance(tree, tuple): - return tuple( + elif isinstance(tree, (list, tuple)): + TreeType = type(tree) + return TreeType( tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) for i, child in enumerate(tree) ) @@ -141,8 +137,8 @@ def tree_unflatten(tree): keys = sorted((int(idx), idx) for idx in children.keys()) l = [] for i, k in keys: - while i > len(l): - l.append({}) + # if i <= len(l), no {} will be appended. + l.extend([{} for _ in range(i - len(l))]) l.append(tree_unflatten(children[k])) return l else: