Support LR schedulers (#334)

* Add a few LR schedulers

* Move parent's constructor call to the top

* Fix docstring

* Refactor optimizers into two files

* Add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Srimukh Sripada
Date: 2024-02-15 20:26:20 +01:00
Committed by: GitHub
Parent: 85143fecdd
Commit: 818cda16bc
10 changed files with 235 additions and 47 deletions
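For context, here is a minimal sketch of the usage pattern the new tests exercise: a learning-rate schedule from mlx.optimizers is passed in place of a fixed learning rate, and the optimizer evaluates it against its own step counter. The calls and signatures are taken from the test code in the diff below; the choice of SGD and the toy parameters are illustrative assumptions, not the canonical API.

import mlx.core as mx
import mlx.optimizers as opt

# Schedule from the tests below: lr(step) = 0.1 * 0.9 ** (step // 1000)
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)

# Pass the schedule instead of a float; the learning rate then decays
# automatically as the optimizer takes steps.
optimizer = opt.SGD(learning_rate=lr_schedule)

params = {"w": mx.ones((5, 5))}
grads = {"w": mx.ones_like(params["w"])}
for _ in range(5):
    params = optimizer.apply_gradients(grads, params)
    print(optimizer.learning_rate)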


@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 import inspect
+import math
 import unittest
 from functools import partial
@@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
 def get_all_optimizers():
     classes = dict()
     for name, obj in inspect.getmembers(opt):
-        if inspect.isclass(obj):
-            if obj.__name__ not in ["Optimizer"]:
-                classes[name] = obj
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, opt.Optimizer)
+            and obj != opt.Optimizer
+        ):
+            classes[name] = obj
     return classes
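The tightened check above matters because the optimizers module now exports more than Optimizer subclasses (notably the schedule factories such as step_decay). A hypothetical illustration of what the new filter collects, assuming the same module layout the tests rely on:

import inspect
import mlx.optimizers as opt

# Collect only concrete Optimizer subclasses; plain functions such as
# step_decay/exponential_decay/cosine_decay and the Optimizer base
# class itself are skipped.
optimizers = {
    name: cls
    for name, cls in inspect.getmembers(opt, inspect.isclass)
    if issubclass(cls, opt.Optimizer) and cls is not opt.Optimizer
}
print(sorted(optimizers))  # e.g. includes "Adafactor", "Adam", "SGD", ...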
@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         x = mx.zeros((5, 5))
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
 
         x = mx.zeros((5, 5), mx.float16)
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
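The change above drops the explicit optimizer.init(x) and the low-level apply_single call in favor of apply_gradients, which sets up the optimizer state on first use. A minimal sketch of the updated pattern, using only calls that appear in the diff:

import mlx.core as mx
import mlx.optimizers as opt

x = mx.zeros((5, 5))
grad = mx.ones_like(x)

optimizer = opt.Adafactor()
for _ in range(2):
    # apply_gradients initializes the Adafactor state lazily,
    # so no separate init(x) call is needed.
    x = optimizer.apply_gradients(grad, x)

print(optimizer.state["step"])  # 2, matching the assertion in the test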
@@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
 
 
+class TestSchedulers(unittest.TestCase):
+    def test_decay_lr(self):
+        for optim_class in optimizers_dict.values():
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            optimizer = optim_class(learning_rate=lr_schedule)
+
+            params = {"w": mx.ones((5, 5))}
+            grads = tree_map(lambda x: mx.ones_like(x), params)
+
+            for it in range(10):
+                expected_lr = 0.1 * (0.9**it)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
+                return optimizer.apply_gradients(grads, params)
+
+    def test_step_decay(self):
+        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+        lr = lr_schedule(2500)
+        expected_lr = 0.1 * (0.9**2)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_exponential_decay(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.99)
+        lr = lr_schedule(10)
+        expected_lr = 0.1 * (0.99**10)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_cosine_decay(self):
+        lr_schedule = opt.cosine_decay(0.1, 10)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_compile_with_schedule(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.9)
+        optimizer = opt.SGD(learning_rate=lr_schedule)
+
+        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
+        def update():
+            optimizer.update({}, {})
+
+        for step in range(5):
+            update()
+            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
+
+
 if __name__ == "__main__":
     unittest.main()
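Relatedly, test_compile_with_schedule above exercises how schedules interact with mx.compile: by listing optimizer.state as compile inputs and outputs, the step counter (and therefore the scheduled learning rate) keeps advancing across compiled calls. A hedged sketch of that pattern outside the test harness; the empty update({}, {}) mirrors the test and stands in for a real gradient step:

from functools import partial

import mlx.core as mx
import mlx.optimizers as opt

lr_schedule = opt.exponential_decay(1e-1, 0.9)
optimizer = opt.SGD(learning_rate=lr_schedule)

# Capture the optimizer state so the compiled function can mutate it.
@partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
def step():
    # A real training step would compute grads and call optimizer.update(model, grads).
    optimizer.update({}, {})

for it in range(5):
    step()
    # After the it-th call the learning rate matches lr_schedule(it),
    # i.e. 0.1 * 0.9 ** it, as asserted in the test.
    print(optimizer.learning_rate.item())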