From ae5b5cabfd2a3b8303abc571c673c4a84ca71c6b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 15 Aug 2024 07:33:23 -0700
Subject: [PATCH] Fix optimizer reloading from checkpoint (#1329)

* fix optimizer reloading from checkpoint

* comment
---
 python/mlx/optimizers/optimizers.py | 26 +++++++++++++++--
 python/tests/test_optimizers.py     | 44 +++++++++++++++++++++++------
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 892e8b40f..a997a031d 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -48,8 +48,28 @@ class Optimizer:
         >>> optimizer.state.keys()
         dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
-        self._state.update(tree_map(lambda x: {}, parameters))
-        tree_map(self.init_single, parameters, self._state)
+
+        # Initialize the optimizer state to match the parameter state
+        def update_state(params, state):
+            if isinstance(params, (list, tuple)):
+                state = list(state)
+                for i in range(len(state)):
+                    state[i] = update_state(params[i], state[i])
+                if len(state) != len(params):
+                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                return type(params)(state)
+            elif isinstance(params, dict):
+                for k, v in params.items():
+                    if k not in state:
+                        state[k] = tree_map(lambda x: {}, v)
+                    else:
+                        state[k] = update_state(v, state[k])
+                return state
+            else:
+                return state
+
+        update_state(parameters, self._state)
+        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
         self._initialized = True
 
     def init_single(self, parameter: mx.array, state: dict):
@@ -104,7 +124,7 @@ class Optimizer:
 
     @state.setter
     def state(self, state: dict):
-        self._initialized = True
+        self._initialized = False
         self._state = state
 
     @property
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 1a6e5e431..d817c10e9 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
 def get_all_optimizers():
@@ -206,20 +206,22 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 
     def test_adafactor(self):
         x = mx.zeros((5, 5))
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)
 
         x = mx.zeros((5, 5), mx.float16)
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
 
     def test_compiled_optimizer(self):
@@ -420,6 +422,30 @@ class TestSchedulers(unittest.TestCase):
             "Gradients were not scaled correctly during clipping.",
         )
 
+    def test_init_from_state(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.init(model.trainable_parameters())
+
+        # Flatten the state for serialization
+        state = tree_flatten(optimizer.state)
+
+        # Make a new optimizer and load the state
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.state = tree_unflatten(state)
+
+        # This should work without any errors
+        grads = model.trainable_parameters()
+        optimizer.update(model, grads)
+
 
 if __name__ == "__main__":
     unittest.main()
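
For reference, a minimal sketch of the checkpoint round trip this patch enables, outside the diff itself. It assumes mx.save_safetensors / mx.load for serialization together with the mlx.utils tree helpers; the model, learning rate, and the file name "optimizer.safetensors" are illustrative only, not part of the change.

# Minimal sketch: save optimizer state to disk, then reload it into a fresh
# optimizer. The file name "optimizer.safetensors" is a placeholder.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Linear(2, 2)
optimizer = opt.Adam(learning_rate=3e-4)
optimizer.init(model.trainable_parameters())

# Save: flatten the nested state into a flat dict of arrays.
mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))

# Reload into a fresh optimizer. With this patch, assigning to `state` marks
# the optimizer as uninitialized, so the next update() re-runs init() and
# fills in any entries the checkpoint does not cover (e.g. modules with no
# trainable parameters) instead of failing.
optimizer = opt.Adam(learning_rate=3e-4)
optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.update(model, model.trainable_parameters())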