mirror of https://github.com/ml-explore/mlx.git
Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint
* comment
This commit is contained in:
parent d0630ffe8c
commit ae5b5cabfd
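
The change targets the round trip of optimizer state through a checkpoint: flatten the state, serialize it, rebuild it with tree_unflatten, assign it back via `optimizer.state`, and keep training. A minimal sketch of that flow, condensed from the new `test_init_from_state` test in the last hunk below (the module and learning rate are only illustrative):

    import mlx.nn as nn
    import mlx.optimizers as opt
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(2, 2)
    optimizer = opt.Adam(learning_rate=3e-4)
    optimizer.init(model.trainable_parameters())

    # Serialize: flatten the nested state into (key, array) pairs.
    checkpoint = tree_flatten(optimizer.state)

    # Restore into a fresh optimizer and keep training.
    optimizer = opt.Adam(learning_rate=3e-4)
    optimizer.state = tree_unflatten(checkpoint)
    optimizer.update(model, model.trainable_parameters())  # the step this commit fixes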
@@ -48,8 +48,28 @@ class Optimizer:
         >>> optimizer.state.keys()
         dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
-        self._state.update(tree_map(lambda x: {}, parameters))
-        tree_map(self.init_single, parameters, self._state)
+
+        # Initialize the optimizer state to match the parameter state
+        def update_state(params, state):
+            if isinstance(params, (list, tuple)):
+                state = list(state)
+                for i in range(len(state)):
+                    state[i] = update_state(params[i], state[i])
+                if len(state) != len(params):
+                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                return type(params)(state)
+            elif isinstance(params, dict):
+                for k, v in params.items():
+                    if k not in state:
+                        state[k] = tree_map(lambda x: {}, v)
+                    else:
+                        state[k] = update_state(v, state[k])
+                return state
+            else:
+                return state
+
+        update_state(parameters, self._state)
+        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
         self._initialized = True
 
     def init_single(self, parameter: mx.array, state: dict):
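
Instead of unconditionally resetting every per-parameter entry, `init` now merges: the recursive `update_state` walks the parameter tree alongside the existing state, keeps entries that are already there (e.g. restored from a checkpoint), inserts empty dicts for entries the loaded state does not contain (such as the Dropout and ReLU entries in the new test's model), and pads state lists that are shorter than the corresponding parameter lists. The follow-up tree_map then relies on an empty dict being falsy: `s or self.init_single(p, s)` short-circuits for restored entries and only calls `init_single`, which the concrete optimizers implement by filling the dict in place, for the still-empty ones. A tiny illustration of that idiom in plain Python (not the MLX API; the moment values are made up):

    def init_single(parameter, state):
        # Stand-in for Optimizer.init_single: fill an empty per-parameter
        # state dict in place.
        state["m"] = 0.0
        state["v"] = 0.0

    restored = {"m": 1.5, "v": 0.2}   # came from a checkpoint: left untouched
    fresh = {}                        # missing from the checkpoint: filled in
    for param, state in [("w1", restored), ("w2", fresh)]:
        state or init_single(param, state)   # the same short-circuit as the tree_map above
    print(restored)   # {'m': 1.5, 'v': 0.2}
    print(fresh)      # {'m': 0.0, 'v': 0.0}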
@@ -104,7 +124,7 @@ class Optimizer:
 
     @state.setter
     def state(self, state: dict):
-        self._initialized = True
+        self._initialized = False
         self._state = state
 
     @property
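
Flipping `_initialized` to False in the setter is the other half of the fix: assigning a state tree directly, as you do when reloading a checkpoint, no longer marks the optimizer as fully initialized, so the merge-and-lazy-init path above runs on the next update and completes whatever the assigned tree is missing, as in the round-trip sketch near the top. The remaining hunks update the optimizer test suite to match.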
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
 def get_all_optimizers():
@@ -206,20 +206,22 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 
     def test_adafactor(self):
         x = mx.zeros((5, 5))
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)
 
         x = mx.zeros((5, 5), mx.float16)
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
 
     def test_compiled_optimizer(self):
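
The Adafactor test previously passed bare arrays; it now passes matching trees for gradients and parameters, which is the calling convention of apply_gradients: gradients first, parameters second, updated parameter tree returned. A standalone sketch of that functional-style usage (the optimizer choice, shapes, and values are arbitrary):

    import mlx.core as mx
    import mlx.optimizers as opt

    params = {"w": mx.zeros((5, 5))}
    grads = {"w": mx.ones_like(params["w"])}

    optimizer = opt.SGD(learning_rate=0.1)
    params = optimizer.apply_gradients(grads, params)  # returns the updated tree
    print(params["w"])  # plain SGD: every entry is now -0.1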
@@ -420,6 +422,30 @@ class TestSchedulers(unittest.TestCase):
             "Gradients were not scaled correctly during clipping.",
         )
 
+    def test_init_from_state(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.init(model.trainable_parameters())
+
+        # Flatten the state for serialization
+        state = tree_flatten(optimizer.state)
+
+        # Make a new optimizer and load the state
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.state = tree_unflatten(state)
+
+        # This should work without any errors
+        grads = model.trainable_parameters()
+        optimizer.update(model, grads)
+
 
 if __name__ == "__main__":
     unittest.main()
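
The new test round-trips the flattened state in memory; a real checkpoint would write those pairs to disk and read them back. One way to do that, sketched under the assumption that every leaf in the state is an mx.array (which appears to hold for the built-in optimizers, whose step and learning rate are stored as arrays):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)
    optimizer = opt.Adam(learning_rate=1e-3)
    optimizer.init(model.trainable_parameters())

    # Save: flat keys like "step", "learning_rate", "weight.m" map to arrays.
    mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))

    # Load: rebuild the nested tree and hand it to a fresh optimizer.
    restored = opt.Adam(learning_rate=1e-3)
    restored.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))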