Mirror of https://github.com/ml-explore/mlx.git
Support LR schedulers (#334)
* Add a few LR schedulers
* Move parent's constructor call to the top
* Fix docstring
* Refactor optimizers into two files
* Add docs
* nit
* Fix Callable type annotation for Python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
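For context, the schedulers added here are callables that map an optimizer step count to a learning rate, and an optimizer accepts one directly in place of a fixed float. A minimal usage sketch, inferred from the tests in the diff below rather than taken from the MLX docs (names and import paths follow the test file):

```python
import mlx.core as mx
import mlx.optimizers as opt
from mlx.utils import tree_map

# A schedule is just a callable: step count in, learning rate out.
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)  # decay by 0.9 every 1000 steps
optimizer = opt.SGD(learning_rate=lr_schedule)

params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda p: mx.ones_like(p), params)

for _ in range(3):
    params = optimizer.apply_gradients(grads, params)
    print(optimizer.learning_rate)  # scheduled value for the current optimizer step
```

The hunks below modify the optimizer test suite (presumably python/tests/test_optimizers.py).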
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import inspect
import math
import unittest
from functools import partial

@@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
 def get_all_optimizers():
     classes = dict()
     for name, obj in inspect.getmembers(opt):
-        if inspect.isclass(obj):
-            if obj.__name__ not in ["Optimizer"]:
-                classes[name] = obj
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, opt.Optimizer)
+            and obj != opt.Optimizer
+        ):
+            classes[name] = obj
     return classes

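The rewritten filter keeps only concrete Optimizer subclasses rather than relying on class names, which is more robust now that the module exposes additional helpers alongside the optimizers. A standalone sketch of the same idea, assuming mlx.optimizers is imported as opt:

```python
import inspect

import mlx.optimizers as opt

# Collect every concrete optimizer class exposed by the module; anything that
# is not an Optimizer subclass (helpers, the base class itself) is skipped.
optimizers = {
    name: cls
    for name, cls in inspect.getmembers(opt, inspect.isclass)
    if issubclass(cls, opt.Optimizer) and cls is not opt.Optimizer
}
print(sorted(optimizers))
```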
@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         x = mx.zeros((5, 5))
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)

         x = mx.zeros((5, 5), mx.float16)
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)

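This hunk swaps the low-level apply_single call, which needed an explicit init and an explicit state argument, for apply_gradients, which returns the updated parameters and maintains optimizer.state, including the step counter asserted at the end. A small sketch of that call pattern, assuming the same imports as the test file:

```python
import mlx.core as mx
import mlx.optimizers as opt

x = mx.zeros((5, 5), mx.float16)
grad = mx.ones_like(x)

optimizer = opt.Adafactor()
for _ in range(2):
    # apply_gradients returns the updated parameters and advances the
    # optimizer state; no explicit init() call is required.
    x = optimizer.apply_gradients(grad, x)

print(optimizer.state["step"])  # 2 after two updates, matching the assertion above
```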
@@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


+class TestSchedulers(unittest.TestCase):
+    def test_decay_lr(self):
+        for optim_class in optimizers_dict.values():
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            optimizer = optim_class(learning_rate=lr_schedule)
+
+            params = {"w": mx.ones((5, 5))}
+            grads = tree_map(lambda x: mx.ones_like(x), params)
+
+            for it in range(10):
+                expected_lr = 0.1 * (0.9**it)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
+                return optimizer.apply_gradients(grads, params)
+
+    def test_step_decay(self):
+        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+        lr = lr_schedule(2500)
+        expected_lr = 0.1 * (0.9**2)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_exponential_decay(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.99)
+        lr = lr_schedule(10)
+        expected_lr = 0.1 * (0.99**10)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_cosine_decay(self):
+        lr_schedule = opt.cosine_decay(0.1, 10)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_compile_with_schedule(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.9)
+        optimizer = opt.SGD(learning_rate=lr_schedule)
+
+        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
+        def update():
+            optimizer.update({}, {})
+
+        for step in range(5):
+            update()
+            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
+
+
 if __name__ == "__main__":
     unittest.main()
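The expected values in these tests pin down the three decay rules: step decay is init * decay**(step // step_size), exponential decay is init * decay**step, and cosine decay follows 0.5 * (1 + cos(pi * step / decay_steps)). Below is a sketch of schedules with those semantics, written as plain closures; the actual mlx.optimizers implementations may differ (for instance, they may return mx.array values so the schedule can run inside compiled functions, which is what test_compile_with_schedule exercises).

```python
import math


def step_decay(init, decay_rate, step_size):
    # lr = init * decay_rate ** (step // step_size)
    return lambda step: init * (decay_rate ** (step // step_size))


def exponential_decay(init, decay_rate):
    # lr = init * decay_rate ** step
    return lambda step: init * (decay_rate ** step)


def cosine_decay(init, decay_steps):
    # lr follows a half cosine from init down to 0 over decay_steps
    def schedule(step):
        t = min(step, decay_steps)
        return init * 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))

    return schedule


# The same checks the tests above perform:
assert abs(step_decay(1e-1, 0.9, 1000)(2500) - 0.1 * 0.9**2) < 1e-7
assert abs(exponential_decay(1e-1, 0.99)(10) - 0.1 * 0.99**10) < 1e-7
assert abs(cosine_decay(0.1, 10)(4) - 0.1 * 0.5 * (1 + math.cos(math.pi * 4 / 10))) < 1e-7
```

The compile test also shows the intended pattern for compiled training steps: optimizer.state is passed as both inputs and outputs to mx.compile, so the step counter, and with it the scheduled learning rate, keeps advancing across compiled calls.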