mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	minor fix in optimizer + docs (#1264)
This commit is contained in:
		| @@ -31,6 +31,41 @@ model's parameters and the **optimizer state**. | ||||
|             # Compute the new parameters but also the optimizer state. | ||||
|             mx.eval(model.parameters(), optimizer.state) | ||||
|  | ||||
| Saving and Loading | ||||
| ------------------ | ||||
|  | ||||
| To serialize an optimizer, save its state. To load an optimizer, load and set | ||||
| the saved state. Here's a simple example: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|    import mlx.core as mx | ||||
|    from mlx.utils import tree_flatten, tree_unflatten | ||||
|    import mlx.optimizers as optim | ||||
|  | ||||
|    optimizer = optim.Adam(learning_rate=1e-2) | ||||
|  | ||||
|    # Perform some updates with the optimizer | ||||
|    model = {"w" : mx.zeros((5, 5))} | ||||
|    grads = {"w" : mx.ones((5, 5))} | ||||
|    optimizer.update(model, grads) | ||||
|  | ||||
|    # Save the state | ||||
|    state = tree_flatten(optimizer.state) | ||||
|    mx.save_safetensors("optimizer.safetensors", dict(state)) | ||||
|  | ||||
|    # Later on, for example when loading from a checkpoint, | ||||
|    # recreate the optimizer and load the state | ||||
|    optimizer = optim.Adam(learning_rate=1e-2) | ||||
|  | ||||
|    state = tree_unflatten(list(mx.load("optimizer.safetensors").items())) | ||||
|    optimizer.state = state | ||||
|  | ||||
| Note, not every optimizer configuation parameter is saved in the state. For | ||||
| example, for Adam the learning rate is saved but the ``betas`` and ``eps`` | ||||
| parameters are not. A good rule of thumb is if the parameter can be scheduled | ||||
| then it will be included in the optimizer state. | ||||
|  | ||||
| .. toctree:: | ||||
|  | ||||
|    optimizers/optimizer | ||||
|   | ||||
| @@ -104,6 +104,7 @@ class Optimizer: | ||||
|  | ||||
|     @state.setter | ||||
|     def state(self, state: dict): | ||||
|         self._initialized = True | ||||
|         self._state = state | ||||
|  | ||||
|     @property | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun