minor fix in optimizer + docs (#1264)
parent 218047c75a
commit 8c01a7893b
@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
     # Compute the new parameters but also the optimizer state.
     mx.eval(model.parameters(), optimizer.state)
 
+Saving and Loading
+------------------
+
+To serialize an optimizer, save its state. To load an optimizer, load and set
+the saved state. Here's a simple example:
+
+.. code-block:: python
+
+    import mlx.core as mx
+    from mlx.utils import tree_flatten, tree_unflatten
+    import mlx.optimizers as optim
+
+    optimizer = optim.Adam(learning_rate=1e-2)
+
+    # Perform some updates with the optimizer
+    model = {"w": mx.zeros((5, 5))}
+    grads = {"w": mx.ones((5, 5))}
+    optimizer.update(model, grads)
+
+    # Save the state
+    state = tree_flatten(optimizer.state)
+    mx.save_safetensors("optimizer.safetensors", dict(state))
+
+    # Later on, for example when loading from a checkpoint,
+    # recreate the optimizer and load the state
+    optimizer = optim.Adam(learning_rate=1e-2)
+
+    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    optimizer.state = state
+
+Note that not every optimizer configuration parameter is saved in the state.
+For example, for Adam the learning rate is saved but the ``betas`` and
+``eps`` parameters are not. A good rule of thumb is that if a parameter can
+be scheduled, it will be included in the optimizer state.
+
 .. toctree::
 
    optimizers/optimizer
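
The rule of thumb in the added note is easy to check: flatten the state after an update and look at the keys. A minimal sketch, assuming the current Adam state stores per-parameter moments under ``m`` and ``v`` (the exact key names may differ across MLX versions):

    import mlx.core as mx
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten

    # A scheduled learning rate lives in the state, so it survives the
    # save/load round trip; fixed hyperparameters like betas and eps do not.
    optimizer = optim.Adam(learning_rate=optim.cosine_decay(1e-2, 1000))
    model = {"w": mx.zeros((5, 5))}
    optimizer.update(model, {"w": mx.ones((5, 5))})

    # Expect dotted keys such as "step", "learning_rate", "w.m", "w.v",
    # and no entries for betas or eps.
    for key, _ in tree_flatten(optimizer.state):
        print(key)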
@@ -104,6 +104,7 @@ class Optimizer:
 
     @state.setter
     def state(self, state: dict):
+        self._initialized = True
         self._state = state
 
     @property
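
As for the one-line code change: MLX optimizers build their per-parameter state lazily, and before this fix assigning ``optimizer.state`` directly left the optimizer looking uninitialized, so a later update could rebuild fresh state over the one just restored (assuming ``update`` re-initializes when the flag is unset, which is what the flag's name suggests). A self-contained toy sketch of that pattern — an illustration of the mechanism, not the MLX source:

    class LazyOptimizer:
        """Toy optimizer illustrating lazy state initialization."""

        def __init__(self):
            self._initialized = False
            self._state = {}

        @property
        def state(self):
            return self._state

        @state.setter
        def state(self, state: dict):
            # The fix: mark the optimizer as initialized so the next
            # update() does not call init() and clobber restored state.
            self._initialized = True
            self._state = state

        def init(self, parameters):
            # Fresh zeroed state for every parameter.
            self._state = {k: 0.0 for k in parameters}
            self._initialized = True

        def update(self, parameters, grads):
            if not self._initialized:
                self.init(parameters)  # would discard a restored state
            for k in grads:
                self._state[k] += 1.0  # stand-in for real moment updates


    opt = LazyOptimizer()
    opt.state = {"w": 42.0}            # restore from a checkpoint
    opt.update({"w": 0.0}, {"w": 1.0})
    print(opt.state)                   # {'w': 43.0}: restored state preserved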