diff --git a/python/mlx/utils.py b/python/mlx/utils.py index a7f68effa..3a9edec6e 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,5 @@ # Copyright © 2023 Apple Inc. -from collections import OrderedDict, defaultdict +from collections import defaultdict from itertools import zip_longest from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -129,7 +129,7 @@ def tree_flatten( # [("0.0.0", 0)] print(tree_flatten([[[0]]], prefix=".hello")) # [("hello.0.0.0", 0)] - tree_flatten({"a": {"b": 1}}) + tree_flatten({"a": {"b": 1}}, destination=dict()) {"a.b": 1} .. note:: @@ -148,7 +148,7 @@ def tree_flatten( List[Tuple[str, Any]]: The flat representation of the Python tree. """ if destination is None: - destination = OrderedDict() + destination = [] def _add_to_destination(k: str, v: Any) -> None: key: str = k[1:] if k.startswith(".") else k