Mirror of https://github.com/ml-explore/mlx.git
Support LR schedulers (#334)

* Add a few LR schedulers
* Move parent's constructor call to the top
* Fix docstring
* Refactor optimizers into two files
* Add docs
* Nit
* Fix Callable type annotation for Python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
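The schedulers added here are plain callables that map a step count to a learning rate, and the optimizers accept such a callable in place of a fixed `learning_rate`. A minimal usage sketch inferred from the tests in this diff (the `mlx.optimizers` import alias is an assumption):

import mlx.optimizers as opt

# step_decay(init, decay_rate, step_size) multiplies the initial rate by
# decay_rate once every step_size steps:
#   lr(step) = init * decay_rate ** (step // step_size)
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
print(lr_schedule(2500))  # 0.1 * 0.9 ** 2 ≈ 0.081

# Pass the schedule where a constant learning rate used to go; the optimizer
# re-evaluates it from its internal step counter as training progresses.
optimizer = opt.SGD(learning_rate=lr_schedule)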
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import inspect
+import math
 import unittest
 from functools import partial
 
@@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
 def get_all_optimizers():
     classes = dict()
     for name, obj in inspect.getmembers(opt):
-        if inspect.isclass(obj):
-            if obj.__name__ not in ["Optimizer"]:
-                classes[name] = obj
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, opt.Optimizer)
+            and obj != opt.Optimizer
+        ):
+            classes[name] = obj
     return classes
 
 
@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         x = mx.zeros((5, 5))
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
 
         x = mx.zeros((5, 5), mx.float16)
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
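The Adafactor test now drives updates through `apply_gradients`, which both applies the update and advances the optimizer's step counter (hence the `state["step"] == 2` assertion above). A hedged sketch of the call shape, assuming bare arrays are accepted as in the test:

# Based on the test above: apply_gradients(grads, params) returns the
# updated parameters and increments optimizer.state["step"] as a side effect.
import mlx.core as mx
import mlx.optimizers as opt

x = mx.zeros((5, 5))
grad = mx.ones_like(x)
optimizer = opt.Adafactor()

x = optimizer.apply_gradients(grad, x)
print(optimizer.state["step"])  # 1 after a single update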
@@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
 
 
+class TestSchedulers(unittest.TestCase):
+    def test_decay_lr(self):
+        for optim_class in optimizers_dict.values():
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
+            optimizer = optim_class(learning_rate=lr_schedule)
+
+            params = {"w": mx.ones((5, 5))}
+            grads = tree_map(lambda x: mx.ones_like(x), params)
+
+            for it in range(10):
+                optimizer.apply_gradients(grads, params)
+                expected_lr = 0.1 * (0.9**it)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
+
+    def test_step_decay(self):
+        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+        lr = lr_schedule(2500)
+        expected_lr = 0.1 * (0.9**2)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_exponential_decay(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.99)
+        lr = lr_schedule(10)
+        expected_lr = 0.1 * (0.99**10)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_cosine_decay(self):
+        lr_schedule = opt.cosine_decay(0.1, 10)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_compile_with_schedule(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.9)
+        optimizer = opt.SGD(learning_rate=lr_schedule)
+
+        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
+        def update():
+            optimizer.update({}, {})
+
+        for step in range(5):
+            update()
+            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
+
+
 if __name__ == "__main__":
     unittest.main()
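One detail worth calling out from `test_compile_with_schedule`: a scheduled learning rate lives in `optimizer.state`, so a compiled training step has to declare that state as compile inputs and outputs, otherwise the step counter (and with it the learning rate) would not advance across calls. A sketch of the same pattern with the cosine schedule:

# Mirrors the pattern in test_compile_with_schedule; cosine_decay(init, T)
# follows lr(step) = init * 0.5 * (1 + cos(pi * step / T)).
from functools import partial

import mlx.core as mx
import mlx.optimizers as opt

optimizer = opt.SGD(learning_rate=opt.cosine_decay(0.1, 10))

# Declaring optimizer.state as input/output state lets the compiled
# function mutate the step counter and learning rate between calls.
@partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
def step():
    optimizer.update({}, {})  # empty model/grads are enough to tick the state

for _ in range(3):
    step()
print(optimizer.learning_rate)  # has decayed along the half-cosine from 0.1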
Author: Srimukh Sripada