From 8c01a7893beb3353a5efce67a0dfa72240023125 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 12 Jul 2024 12:18:02 -0700
Subject: [PATCH] minor fix in optimizer + docs (#1264)

---
 docs/src/python/optimizers.rst      | 35 +++++++++++++++++++++++++++++
 python/mlx/optimizers/optimizers.py |  1 +
 2 files changed, 36 insertions(+)

diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst
index 84ab933ac..1897483d8 100644
--- a/docs/src/python/optimizers.rst
+++ b/docs/src/python/optimizers.rst
@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
    # Compute the new parameters but also the optimizer state.
    mx.eval(model.parameters(), optimizer.state)
 
+Saving and Loading
+------------------
+
+To serialize an optimizer, save its state. To restore it later, load the
+saved state and assign it back. Here's a simple example:
+
+.. code-block:: python
+
+   import mlx.core as mx
+   from mlx.utils import tree_flatten, tree_unflatten
+   import mlx.optimizers as optim
+
+   optimizer = optim.Adam(learning_rate=1e-2)
+
+   # Perform some updates with the optimizer
+   model = {"w": mx.zeros((5, 5))}
+   grads = {"w": mx.ones((5, 5))}
+   optimizer.update(model, grads)
+
+   # Save the state
+   state = tree_flatten(optimizer.state)
+   mx.save_safetensors("optimizer.safetensors", dict(state))
+
+   # Later on, for example when loading from a checkpoint,
+   # recreate the optimizer and load the state
+   optimizer = optim.Adam(learning_rate=1e-2)
+
+   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+   optimizer.state = state
+
+Note that not every optimizer configuration parameter is saved in the state.
+For example, for Adam the learning rate is saved but the ``betas`` and
+``eps`` parameters are not. A good rule of thumb is that if a parameter can
+be scheduled, it will be included in the optimizer state.
+
 .. toctree::
 
    optimizers/optimizer
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 58198f1d4..892e8b40f 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -104,6 +104,7 @@ class Optimizer:
 
     @state.setter
     def state(self, state: dict):
+        self._initialized = True
         self._state = state
 
     @property
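Notes
-----

The rule of thumb in the new docs section can be checked directly: after one
update, inspect the keys of ``optimizer.state``. The sketch below assumes the
Adam state layout implied by the docs (a step counter, the schedulable
``learning_rate``, and per-parameter moment estimates); treat the exact key
names as an assumption rather than a guaranteed API.

.. code-block:: python

   import mlx.core as mx
   import mlx.optimizers as optim

   optimizer = optim.Adam(learning_rate=1e-2)

   # One update so the per-parameter state gets created.
   optimizer.update({"w": mx.zeros((5, 5))}, {"w": mx.ones((5, 5))})

   # Schedulable quantities (e.g. step, learning_rate) appear in the state;
   # fixed hyperparameters (betas, eps) should not.
   print(list(optimizer.state.keys()))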
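The one-line change in ``optimizers.py`` is presumably what makes the loading
example safe: if the base ``Optimizer`` initializes its per-parameter state
lazily on the first update while ``_initialized`` is false (inferred from the
flag's name, not shown in this diff), then assigning a loaded state without
setting the flag would let that lazy initialization overwrite it on the next
``update()``. A minimal sketch of the load path the fix protects, under that
assumption:

.. code-block:: python

   import mlx.core as mx
   import mlx.optimizers as optim
   from mlx.utils import tree_unflatten

   optimizer = optim.Adam(learning_rate=1e-2)

   # Restore a previously saved state (file name from the docs example).
   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
   optimizer.state = state

   # With the fix, this update continues from the loaded moments instead of
   # re-initializing the state it just loaded.
   optimizer.update({"w": mx.zeros((5, 5))}, {"w": mx.ones((5, 5))})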