Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-12 23:34:36 +08:00)
Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint
* comment
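For context before the diff: an MLX optimizer's state is a parameter-like tree, and tree_flatten turns it into flat (path, array) pairs for serialization. Below is a minimal checkpoint-saving sketch of that pattern; the mx.save_safetensors call, the "optimizer.safetensors" file name, and the assumption that every state leaf is an mx.array are illustrative and not part of this commit.

# Hypothetical checkpoint-saving sketch (not part of the diff). Assumes every
# leaf of optimizer.state is an mx.array that safetensors can store.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten

model = nn.Linear(2, 2)                        # illustrative model
optimizer = opt.Adam(learning_rate=3e-4)
optimizer.init(model.trainable_parameters())   # materialize per-parameter state

# tree_flatten turns the nested state dict into flat (path, array) pairs,
# e.g. ("weight.m", array), which serialize cleanly as a flat dict.
mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))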
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_unflatten


 def get_all_optimizers():
@@ -206,20 +206,22 @@ class TestOptimizers(mlx_tests.MLXTestCase):

     def test_adafactor(self):
         x = mx.zeros((5, 5))
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)

         x = mx.zeros((5, 5), mx.float16)
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-            self.assertEqual(xp.dtype, x.dtype)
-            self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+            self.assertEqual(xp["x"].dtype, x.dtype)
+            self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)

     def test_compiled_optimizer(self):
@@ -420,6 +422,30 @@ class TestSchedulers(unittest.TestCase):
             "Gradients were not scaled correctly during clipping.",
         )

+    def test_init_from_state(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.init(model.trainable_parameters())
+
+        # Flatten the state for serialization
+        state = tree_flatten(optimizer.state)
+
+        # Make a new optimizer and load the state
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.state = tree_unflatten(state)
+
+        # This should work without any errors
+        grads = model.trainable_parameters()
+        optimizer.update(model, grads)
+

 if __name__ == "__main__":
     unittest.main()
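As a follow-on usage sketch (not from the commit), this mirrors the new test_init_from_state but resumes from disk: mx.load and the "optimizer.safetensors" path are assumptions, while rebuilding the nested state with tree_unflatten and assigning optimizer.state is the pattern the test covers.

# Hypothetical resume-from-checkpoint sketch, paired with the saving sketch above.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_unflatten

model = nn.Linear(2, 2)                        # must match the checkpointed model
optimizer = opt.Adam(learning_rate=3e-4)

# mx.load returns a flat {path: array} dict; tree_unflatten rebuilds the nested
# state tree from its (path, array) items before handing it to the optimizer.
flat_state = mx.load("optimizer.safetensors")
optimizer.state = tree_unflatten(list(flat_state.items()))

# Training can continue where it left off: the restored step counter and
# per-parameter moments are used on the next update.
grads = model.trainable_parameters()           # stand-in for real gradients
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)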