mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
minor fix in optimizer + docs (#1264)
This commit is contained in:
parent
218047c75a
commit
8c01a7893b
@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
|
||||
# Compute the new parameters but also the optimizer state.
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
Saving and Loading
|
||||
------------------
|
||||
|
||||
To serialize an optimizer, save its state. To load an optimizer, load and set
|
||||
the saved state. Here's a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
import mlx.optimizers as optim
|
||||
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
# Perform some updates with the optimizer
|
||||
model = {"w" : mx.zeros((5, 5))}
|
||||
grads = {"w" : mx.ones((5, 5))}
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
|
||||
parameters are not. A good rule of thumb is if the parameter can be scheduled
|
||||
then it will be included in the optimizer state.
|
||||
|
||||
.. toctree::
|
||||
|
||||
optimizers/optimizer
|
||||
|
@ -104,6 +104,7 @@ class Optimizer:
|
||||
|
||||
@state.setter
|
||||
def state(self, state: dict):
|
||||
self._initialized = True
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user