diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index 1897483d88..3320225aed 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -51,14 +51,14 @@ the saved state. Here's a simple example: optimizer.update(model, grads) # Save the state - state = tree_flatten(optimizer.state) - mx.save_safetensors("optimizer.safetensors", dict(state)) + state = tree_flatten(optimizer.state, destination={}) + mx.save_safetensors("optimizer.safetensors", state) # Later on, for example when loading from a checkpoint, # recreate the optimizer and load the state optimizer = optim.Adam(learning_rate=1e-2) - state = tree_unflatten(list(mx.load("optimizer.safetensors").items())) + state = tree_unflatten(mx.load("optimizer.safetensors")) optimizer.state = state Note, not every optimizer configuration parameter is saved in the state. For diff --git a/docs/src/usage/export.rst b/docs/src/usage/export.rst index 8120736093..4d77b464d1 100644 --- a/docs/src/usage/export.rst +++ b/docs/src/usage/export.rst @@ -7,17 +7,17 @@ Exporting Functions MLX has an API to export and import functions to and from a file. This lets you run computations written in one MLX front-end (e.g. Python) in another MLX -front-end (e.g. C++). +front-end (e.g. C++). This guide walks through the basics of the MLX export API with some examples. To see the full list of functions check-out the :ref:`API documentation <export>`. -Basics of Exporting +Basics of Exporting ------------------- Let's start with a simple example: - + .. code-block:: python def fun(x, y): @@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays: x = mx.array(1.0) y = mx.array(1.0) - + # Both arguments to fun are positional mx.export_function("add.mlxfn", fun, x, y) @@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file. For enclosed arrays inside an exported function, be extra careful to ensure they are evaluated. 
The computation graph that gets exported will include the computation that produces enclosed inputs. - + If the above example was missing ``mx.eval(model.parameters())``, the exported function would include the random initialization of the :obj:`mlx.nn.Module` parameters. @@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper: # Set the model's parameters to the input parameters model.update(tree_unflatten(list(params.items()))) return model(x) - - params = dict(tree_flatten(model.parameters())) + + params = tree_flatten(model.parameters(), destination={}) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) @@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes: # Ok out, = imported_abs(mx.array(-1.0)) - - # Also ok + + # Also ok out, = imported_abs(mx.array([-1.0, -2.0])) With ``shapeless=False`` (which is the default), the second call to @@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`: def fun(x, y=None): constant = mx.array(3.0) if y is not None: - x += y + x += y return x + constant with mx.exporter("fun.mlxfn", fun) as exporter: @@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`: print(out) In the above example the function constant data, (i.e. ``constant``), is only -saved once. +saved once. Transformations with Imported Functions --------------------------------------- @@ -238,7 +238,7 @@ on imported functions just like regular Python functions: # Prints: array(1, dtype=float32) print(dfdx(x)) - # Compile the imported function + # Compile the imported function mx.compile(imported_fun) # Prints: array(0, dtype=float32) print(compiled_fun(x)[0]) @@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code: // Prints: array(2, dtype=float32) std::cout << outputs[0] << std::endl; -Imported functions can be transformed in C++ just like in Python. 
Use +Imported functions can be transformed in C++ just like in Python. Use ``std::vector`` for positional arguments and ``std::map`` for keyword arguments when calling imported functions in C++. diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index c3a517d163..1ae9272799 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -178,7 +178,7 @@ class Module(dict): if strict: new_weights = dict(weights) - curr_weights = dict(tree_flatten(self.parameters())) + curr_weights = tree_flatten(self.parameters(), destination={}) if extras := (new_weights.keys() - curr_weights.keys()): num_extra = len(extras) extras = ",\n".join(sorted(extras)) @@ -212,7 +212,7 @@ class Module(dict): - ``.npz`` will use :func:`mx.savez` - ``.safetensors`` will use :func:`mx.save_safetensors` """ - params_dict = dict(tree_flatten(self.parameters())) + params_dict = tree_flatten(self.parameters(), destination={}) if file.endswith(".npz"): mx.savez(file, **params_dict) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index b7173deb71..fa9884c10f 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict from itertools import zip_longest -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union def tree_map( @@ -114,8 +114,11 @@ def tree_map_with_path( def tree_flatten( - tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None -) -> Any: + tree: Any, + prefix: str = "", + is_leaf: Optional[Callable] = None, + destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None, +) -> Union[List[Tuple[str, Any]], Dict[str, Any]]: """Flattens a Python tree to a list of key, value tuples. 
The keys are using the dot notation to define trees of arbitrary depth and @@ -128,9 +131,12 @@ def tree_flatten( print(tree_flatten([[[0]]])) # [("0.0.0", 0)] - print(tree_flatten([[[0]]], ".hello")) + print(tree_flatten([[[0]]], prefix=".hello")) # [("hello.0.0.0", 0)] + print(tree_flatten({"a": {"b": 1}}, destination={})) + # {"a.b": 1} + .. note:: Dictionaries should have keys that are valid Python identifiers. @@ -140,26 +146,50 @@ def tree_flatten( always discarded. is_leaf (callable): An optional callable that returns True if the passed object is considered a leaf or False otherwise. + destination (list or dict, optional): A list or dictionary to store the + flattened tree. If None an empty list will be used. Default: ``None``. Returns: - List[Tuple[str, Any]]: The flat representation of the Python tree. + Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of + the Python tree. """ - flat_tree = [] + if destination is None: + destination = [] - if is_leaf is None or not is_leaf(tree): - if isinstance(tree, (list, tuple)): - for i, t in enumerate(tree): - flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf)) - return flat_tree - if isinstance(tree, dict): - for k, t in tree.items(): - flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf)) - return flat_tree + # Create the function to update the destination. We are taking advantage of + # the fact that list.extend and dict.update have the same API to simplify + # the code a bit. 
+ if isinstance(destination, list): + _add_to_destination = destination.extend + elif isinstance(destination, dict): + _add_to_destination = destination.update + else: + raise ValueError("Destination should be either a list or a dictionary or None") - return [(prefix[1:], tree)] + # Leaf identified by is_leaf so add it and return + if is_leaf is not None and is_leaf(tree): + _add_to_destination([(prefix[1:], tree)]) + return destination + + # List or tuple so recursively add each subtree + if isinstance(tree, (list, tuple)): + for i, item in enumerate(tree): + tree_flatten(item, f"{prefix}.{i}", is_leaf, destination) + return destination + + # Dictionary so recursively add each subtree + if isinstance(tree, dict): + for key, value in tree.items(): + tree_flatten(value, f"{prefix}.{key}", is_leaf, destination) + return destination + + # Leaf so add it and return + _add_to_destination([(prefix[1:], tree)]) + + return destination -def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: +def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: """Recreate a Python tree from its flat representation. .. code-block:: python @@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: print(d) # {"hello": {"world": 42}} + d = tree_unflatten({"hello.world": 42}) + print(d) + # {"hello": {"world": 42}} + Args: - tree (list[tuple[str, Any]]): The flat representation of a Python tree. + tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree. For instance as returned by :meth:`tree_flatten`. Returns: A Python tree. 
""" - if len(tree) == 1 and tree[0][0] == "": - return tree[0][1] + items = tree.items() if isinstance(tree, dict) else tree - try: - int(tree[0][0].split(".", maxsplit=1)[0]) - is_list = True - except ValueError: - is_list = False + # Special case when we have just one element in the tree ie not a tree + if len(items) == 1: + key, value = next(iter(items)) + if key == "": + return value # collect children children = defaultdict(list) - for key, value in tree: + for key, value in items: current_idx, *next_idx = key.split(".", maxsplit=1) next_idx = "" if not next_idx else next_idx[0] children[current_idx].append((next_idx, value)) - # recursively map them to the original container - if is_list: + # Assume they are a list and fail to dict if the keys are not all integers + try: keys = sorted((int(idx), idx) for idx in children.keys()) l = [] for i, k in keys: @@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: l.extend([{} for _ in range(i - len(l))]) l.append(tree_unflatten(children[k])) return l - else: + except ValueError: return {k: tree_unflatten(v) for k, v in children.items()} diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index a771020875..296f6ee8d9 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase): self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))} model = DictModule() - params = dict(tree_flatten(model.parameters())) + params = tree_flatten(model.parameters(), destination={}) self.assertEqual(len(params), 2) self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))