diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 8cb8e90c8..daa387420 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,7 @@ # Copyright © 2023 Apple Inc. +from collections import defaultdict + def tree_map(fn, tree, *rest, is_leaf=None): """Applies ``fn`` to the leaves of the python tree ``tree`` and @@ -128,12 +130,10 @@ def tree_unflatten(tree): is_list = False # collect children - children = {} + children = defaultdict(list) for key, value in tree: current_idx, *next_idx = key.split(".", maxsplit=1) next_idx = "" if not next_idx else next_idx[0] - if current_idx not in children: - children[current_idx] = [] children[current_idx].append((next_idx, value)) # recursively map them to the original container