diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 3a9edec6e..fa9884c10 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -115,21 +115,26 @@ def tree_map_with_path( def tree_flatten( tree: Any, - destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None, prefix: str = "", is_leaf: Optional[Callable] = None, + destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None, ) -> Union[List[Tuple[str, Any]], Dict[str, Any]]: """Flattens a Python tree to a list of key, value tuples. + The keys are using the dot notation to define trees of arbitrary depth and complexity. .. code-block:: python + from mlx.utils import tree_flatten + print(tree_flatten([[[0]]])) # [("0.0.0", 0)] + print(tree_flatten([[[0]]], prefix=".hello")) # [("hello.0.0.0", 0)] - tree_flatten({"a": {"b": 1}}, destination=dict()) + + tree_flatten({"a": {"b": 1}}, destination={}) {"a.b": 1} .. note:: @@ -137,46 +142,50 @@ def tree_flatten( Args: tree (Any): The Python tree to be flattened. - destination(Any, optional): Container to store results. If None, creates an OrderedDict. - Can be a list (for tuples) or dict (for key-value mapping). prefix (str): A prefix to use for the keys. The first character is - always discarded if it starts with a dot. + always discarded. is_leaf (callable): An optional callable that returns True if the passed object is considered a leaf or False otherwise. + destination (list or dict, optional): A list or dictionary to store the + flattened tree. If None an empty list will be used. Default: ``None``. Returns: - List[Tuple[str, Any]]: The flat representation of the Python tree. + Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of + the Python tree. """ if destination is None: destination = [] - def _add_to_destination(k: str, v: Any) -> None: - key: str = k[1:] if k.startswith(".") else k - if isinstance(destination, list): - destination.append((key, v)) - elif isinstance(destination, dict): - destination[key] = v - else: - raise ValueError( - f"Unsupported destination type: {type(destination)}. " - "Must be list, tuple, or dict." - ) + # Create the function to update the destination. We are taking advantage of + # the fact that list.extend and dict.update have the same API to simplify + # the code a bit. + if isinstance(destination, list): + _add_to_destination = destination.extend + elif isinstance(destination, dict): + _add_to_destination = destination.update + else: + raise ValueError("Destination should be either a list or a dictionary or None") + # Leaf identified by is_leaf so add it and return if is_leaf is not None and is_leaf(tree): - _add_to_destination(prefix, tree) + _add_to_destination([(prefix[1:], tree)]) return destination + # List or tuple so recursively add each subtree if isinstance(tree, (list, tuple)): for i, item in enumerate(tree): - tree_flatten(item, destination, f"{prefix}.{i}", is_leaf) + tree_flatten(item, f"{prefix}.{i}", is_leaf, destination) return destination + # Dictionary so recursively add each subtree if isinstance(tree, dict): for key, value in tree.items(): - tree_flatten(value, destination, f"{prefix}.{key}", is_leaf) + tree_flatten(value, f"{prefix}.{key}", is_leaf, destination) return destination - _add_to_destination(prefix, tree) + # Leaf so add it and return + _add_to_destination([(prefix[1:], tree)]) + return destination @@ -191,6 +200,10 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: print(d) # {"hello": {"world": 42}} + d = tree_unflatten({"hello.world": 42}) + print(d) + # {"hello": {"world": 42}} + Args: tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree. For instance as returned by :meth:`tree_flatten`. @@ -198,27 +211,23 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: Returns: A Python tree. """ - if isinstance(tree, dict): - tree = list(tree.items()) + items = tree.items() if isinstance(tree, dict) else tree - if len(tree) == 1 and tree[0][0] == "": - return tree[0][1] - - try: - int(tree[0][0].split(".", maxsplit=1)[0]) - is_list = True - except ValueError: - is_list = False + # Special case when we have just one element in the tree ie not a tree + if len(items) == 1: + key, value = next(iter(items)) + if key == "": + return value # collect children children = defaultdict(list) - for key, value in tree: + for key, value in items: current_idx, *next_idx = key.split(".", maxsplit=1) next_idx = "" if not next_idx else next_idx[0] children[current_idx].append((next_idx, value)) - # recursively map them to the original container - if is_list: + # Assume they are a list and fail to dict if the keys are not all integers + try: keys = sorted((int(idx), idx) for idx in children.keys()) l = [] for i, k in keys: @@ -226,7 +235,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: l.extend([{} for _ in range(i - len(l))]) l.append(tree_unflatten(children[k])) return l - else: + except ValueError: return {k: tree_unflatten(v) for k, v in children.items()} diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 53b6900f9..cb1cd478a 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -30,16 +30,15 @@ class TestBase(mlx_tests.MLXTestCase): self.assertEqual(len(flat_children), 3) leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) - if isinstance(leaves, list): - self.assertEqual(len(leaves), 4) - self.assertEqual(leaves[0][0], "layers.0.layers.0") - self.assertEqual(leaves[1][0], "layers.1.layers.0") - self.assertEqual(leaves[2][0], "layers.1.layers.1") - self.assertEqual(leaves[3][0], "layers.2") - self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) - self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) - self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) - self.assertTrue(leaves[3][1] is m.layers[2]) + self.assertEqual(len(leaves), 4) + self.assertEqual(leaves[0][0], "layers.0.layers.0") + self.assertEqual(leaves[1][0], "layers.1.layers.0") + self.assertEqual(leaves[2][0], "layers.1.layers.1") + self.assertEqual(leaves[3][0], "layers.2") + self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) + self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) + self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) + self.assertTrue(leaves[3][1] is m.layers[2]) m.eval()