	Support LR schedulers (#334)
* Add a few LR schedulers
* Move parent's constructor call to the top
* Fix docstring
* Refactor optimizers into two files
* Add docs
* Nit
* Fix Callable type annotation for Python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
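Taken together, the changes below let an optimizer accept either a fixed float or a schedule (a callable from the step counter to a value) as its `learning_rate`. A minimal usage sketch, assuming the import convention used in the docstrings added by this commit:

```python
# Sketch of the API introduced here; names follow the docstrings in the diff below.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# A schedule is just a callable mapping the optimizer's step counter to a value.
lr_schedule = optim.cosine_decay(1e-1, decay_steps=1000)
optimizer = optim.SGD(learning_rate=lr_schedule)

model = nn.Linear(2, 2)
optimizer.init(model.trainable_parameters())
print(optimizer.learning_rate)  # array(0.1, dtype=float32) at step 0
```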
python/mlx/optimizers/__init__.py  (new file, 4 lines added)
							| @@ -0,0 +1,4 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| from mlx.optimizers.optimizers import * | ||||
| from mlx.optimizers.schedulers import * | ||||
| @@ -1,7 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
| 
 | ||||
| import math | ||||
| from typing import List, Optional, Tuple | ||||
| from typing import Callable, List, Optional, Tuple, Union | ||||
| 
 | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
| @@ -12,9 +12,10 @@ class Optimizer: | ||||
|     optimizer on a per-parameter basis and apply it to a parameter tree. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self): | ||||
|     def __init__(self, schedulers=None): | ||||
|         self._initialized = False | ||||
|         self._state = {} | ||||
|         self._state = {"step": mx.array(0, mx.uint64)} | ||||
|         self._schedulers = {k: v for k, v in (schedulers or {}).items()} | ||||
| 
 | ||||
|     def update(self, model: "mlx.nn.Module", gradients: dict): | ||||
|         """Apply the gradients to the parameters of the model and update the | ||||
| @@ -44,9 +45,8 @@ class Optimizer: | ||||
|             >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) | ||||
|             >>> model = nn.Linear(2, 2) | ||||
|             >>> optimizer.init(model.trainable_parameters()) | ||||
|             >>> optimizer.state | ||||
|             {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], | ||||
|                    [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} | ||||
|             >>> optimizer.state.keys() | ||||
|             dict_keys(['step', 'learning_rate', 'weight', 'bias']) | ||||
|         """ | ||||
|         self._state.update(tree_map(lambda x: {}, parameters)) | ||||
|         tree_map(self.init_single, parameters, self._state) | ||||
| @@ -76,6 +76,15 @@ class Optimizer: | ||||
|         """ | ||||
|         if not self._initialized: | ||||
|             self.init(gradients) | ||||
| 
 | ||||
|         # Update any scheduled variables | ||||
|         for param, scheduler in self._schedulers.items(): | ||||
|             self.state[param] = scheduler(self.step) | ||||
| 
 | ||||
|         # Increment the step | ||||
|         self.state["step"] = self.step + 1 | ||||
| 
 | ||||
|         # Apply the update | ||||
|         return tree_map(self.apply_single, gradients, parameters, self.state) | ||||
| 
 | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
| @@ -97,14 +106,31 @@ class Optimizer: | ||||
|     def state(self, state: dict): | ||||
|         self._state = state | ||||
| 
 | ||||
|     @property | ||||
|     def step(self): | ||||
|         return self.state["step"] | ||||
| 
 | ||||
|     @property | ||||
|     def learning_rate(self): | ||||
|         return self.state["learning_rate"] | ||||
| 
 | ||||
|     @learning_rate.setter | ||||
|     def learning_rate(self, learning_rate: mx.array): | ||||
|     def learning_rate(self, learning_rate: Union[float, mx.array]): | ||||
|         self.state["learning_rate"] = mx.array(learning_rate) | ||||
| 
 | ||||
|     def _maybe_schedule( | ||||
|         self, name: str, param: Union[float, Callable[[mx.array], mx.array]] | ||||
|     ): | ||||
|         """ | ||||
|         To be used by derived classes to optionally put a parameter on a schedule. | ||||
|         """ | ||||
|         if isinstance(param, Callable): | ||||
|             self._schedulers[name] = param | ||||
|             param = param(self.step) | ||||
|         else: | ||||
|             param = mx.array(param) | ||||
|         self.state[name] = param | ||||
| 
 | ||||
| 
 | ||||
| class SGD(Optimizer): | ||||
|     r"""The stochastic gradient descent optimizer. | ||||
| @@ -117,7 +143,7 @@ class SGD(Optimizer): | ||||
|         w_{t+1} &= w_t - \lambda v_{t+1} | ||||
| 
 | ||||
|     Args: | ||||
|         learning_rate (float): The learning rate :math:`\lambda`. | ||||
|         learning_rate (float or callable): The learning rate :math:`\lambda`. | ||||
|         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0`` | ||||
|         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0`` | ||||
|         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0`` | ||||
| @@ -126,7 +152,7 @@ class SGD(Optimizer): | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         learning_rate: float, | ||||
|         learning_rate: Union[float, Callable[[mx.array], mx.array]], | ||||
|         momentum: float = 0.0, | ||||
|         weight_decay: float = 0.0, | ||||
|         dampening: float = 0.0, | ||||
| @@ -138,7 +164,7 @@ class SGD(Optimizer): | ||||
|             ) | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.momentum = momentum | ||||
|         self.weight_decay = weight_decay | ||||
|         self.dampening = dampening | ||||
| @@ -194,7 +220,7 @@ class RMSprop(Optimizer): | ||||
|     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.alpha = alpha | ||||
|         self.eps = eps | ||||
| 
 | ||||
| @@ -246,7 +272,7 @@ class Adagrad(Optimizer): | ||||
|     def __init__(self, learning_rate: float, eps: float = 1e-8): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.eps = eps | ||||
| 
 | ||||
|         if self.eps < 0.0: | ||||
| @@ -295,7 +321,7 @@ class AdaDelta(Optimizer): | ||||
|     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.rho = rho | ||||
|         self.eps = eps | ||||
|         if self.rho < 0.0: | ||||
| @@ -361,7 +387,7 @@ class Adam(Optimizer): | ||||
|     ): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.betas = betas | ||||
|         self.eps = eps | ||||
| 
 | ||||
| @@ -526,7 +552,7 @@ class Lion(Optimizer): | ||||
|     ): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.learning_rate = learning_rate | ||||
|         self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.betas = betas | ||||
|         self.weight_decay = weight_decay | ||||
| 
 | ||||
| @@ -596,7 +622,7 @@ class Adafactor(Optimizer): | ||||
|     ): | ||||
|         super().__init__() | ||||
|         if learning_rate is not None: | ||||
|             self.learning_rate = learning_rate | ||||
|             self._maybe_schedule("learning_rate", learning_rate) | ||||
|         self.eps = eps | ||||
|         self.clip_threshold = clip_threshold | ||||
|         self.decay_rate = decay_rate | ||||
| @@ -608,7 +634,6 @@ class Adafactor(Optimizer): | ||||
| 
 | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["step"] = 0 | ||||
|         if parameter.ndim >= 2: | ||||
|             shape = parameter.shape | ||||
|             dtype = parameter.dtype | ||||
| @@ -626,10 +651,11 @@ class Adafactor(Optimizer): | ||||
|     def _compute_learning_rate(self, step, parameter_rms): | ||||
|         if self.relative_step: | ||||
|             min_step = 1e-6 * step if self.warmup_init else 1e-2 | ||||
|             relative_step_size = min(min_step, 1 / math.sqrt(step)) | ||||
|             relative_step_size = mx.minimum(min_step, mx.rsqrt(step)) | ||||
|         else: | ||||
|             relative_step_size = self.learning_rate.astype(parameter_rms) | ||||
|             relative_step_size = self.learning_rate | ||||
| 
 | ||||
|         relative_step_size = relative_step_size.astype(parameter_rms.dtype) | ||||
|         parameter_scale = 1.0 | ||||
|         if self.scale_parameter: | ||||
|             parameter_scale = mx.maximum(self.eps[1], parameter_rms) | ||||
| @@ -648,13 +674,12 @@ class Adafactor(Optimizer): | ||||
|         """Performs the Adafactor parameter and state update.""" | ||||
|         factored = gradient.ndim >= 2 | ||||
| 
 | ||||
|         step = state["step"] + 1 | ||||
|         state["step"] = step | ||||
|         step = self.step | ||||
|         use_first_moment = self.beta_1 is not None | ||||
| 
 | ||||
|         parameter_rms = self._compute_rms(parameter) | ||||
|         learning_rate = self._compute_learning_rate(step, parameter_rms) | ||||
|         beta_2 = 1.0 - math.pow(step, self.decay_rate) | ||||
|         beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype) | ||||
|         update = mx.square(gradient) + self.eps[0] | ||||
| 
 | ||||
|         if factored: | ||||
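The `@@ -1,7 +1,7 @@` hunk and those that follow presumably belong to `python/mlx/optimizers/optimizers.py`: the commit message mentions refactoring the optimizers into two files, and the new `__init__.py` above re-exports from that module. The key mechanism is that `Optimizer` now keeps a `step` counter in its state and re-evaluates every scheduled hyperparameter at the start of each update; subclasses opt in by routing a hyperparameter through `_maybe_schedule`. A hedged sketch of how a custom subclass would plug into this (the `ScheduledSGD` class below is illustrative, not part of the commit):

```python
from typing import Callable, Union

import mlx.core as mx
from mlx.optimizers import Optimizer


class ScheduledSGD(Optimizer):
    """Illustrative subclass: vanilla SGD whose learning rate may be scheduled."""

    def __init__(self, learning_rate: Union[float, Callable[[mx.array], mx.array]]):
        super().__init__()
        # Registers the schedule (if callable) and seeds state["learning_rate"].
        self._maybe_schedule("learning_rate", learning_rate)

    def init_single(self, parameter: mx.array, state: dict):
        pass  # vanilla SGD keeps no per-parameter state

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        # Before this is called per parameter, the base class has already
        # refreshed state["learning_rate"] from the schedule and bumped state["step"].
        return parameter - self.learning_rate * gradient
```

Usage then matches the built-in optimizers, e.g. `ScheduledSGD(learning_rate=optim.step_decay(1e-1, 0.9, 100))`.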
python/mlx/optimizers/schedulers.py  (new file, 86 lines added)
							| @@ -0,0 +1,86 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import math | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
|  | ||||
| def exponential_decay(init: float, decay_rate: float): | ||||
|     r"""Make an exponential decay scheduler. | ||||
|  | ||||
|     Args: | ||||
|         init (float): Initial value. | ||||
|         decay_rate (float): Multiplicative factor to decay by. | ||||
|  | ||||
|     Example: | ||||
|         >>> lr_schedule = optim.exponential_decay(1e-1, 0.9) | ||||
|         >>> optimizer = optim.SGD(learning_rate=lr_schedule) | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.1, dtype=float32) | ||||
|         >>> | ||||
|         >>> for _ in range(5): optimizer.update({}, {}) | ||||
|         ... | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.06561, dtype=float32) | ||||
|     """ | ||||
|  | ||||
|     def schedule(step): | ||||
|         return init * decay_rate**step | ||||
|  | ||||
|     return schedule | ||||
|  | ||||
|  | ||||
| def step_decay(init: float, decay_rate: float, step_size: int): | ||||
|     r"""Make a step decay scheduler. | ||||
|  | ||||
|     Args: | ||||
|         init (float): Initial value. | ||||
|         decay_rate (float): Multiplicative factor to decay by. | ||||
|         step_size (int): Decay every ``step_size`` steps. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10) | ||||
|         >>> optimizer = optim.SGD(learning_rate=lr_schedule) | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.1, dtype=float32) | ||||
|         >>> | ||||
|         >>> for _ in range(21): optimizer.update({}, {}) | ||||
|         ... | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.081, dtype=float32) | ||||
|     """ | ||||
|  | ||||
|     def schedule(step): | ||||
|         return init * (decay_rate ** (step // step_size)) | ||||
|  | ||||
|     return schedule | ||||
|  | ||||
|  | ||||
| def cosine_decay(init: float, decay_steps: int): | ||||
|     r"""Make a cosine decay scheduler. | ||||
|  | ||||
|     Args: | ||||
|         init (float): Initial value. | ||||
|         decay_steps (int): Number of steps to decay over. The decayed | ||||
|             value is constant for steps beyond ``decay_steps``. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         >>> lr_schedule = optim.cosine_decay(1e-1, 1000) | ||||
|         >>> optimizer = optim.SGD(learning_rate=lr_schedule) | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.1, dtype=float32) | ||||
|         >>> | ||||
|         >>> for _ in range(5): optimizer.update({}, {}) | ||||
|         ... | ||||
|         >>> optimizer.learning_rate | ||||
|         array(0.0999961, dtype=float32) | ||||
|     """ | ||||
|  | ||||
|     def scheduler(step): | ||||
|         s = mx.minimum(step, decay_steps) | ||||
|         decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) | ||||
|         return init * decay | ||||
|  | ||||
|     return scheduler | ||||
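Because a scheduler is just a callable that takes the step counter (an `mx.array`) and returns a value, custom schedules compose with the optimizers without any extra machinery. An illustrative sketch under the same import convention; the `linear_warmup` helper is not part of this commit:

```python
import mlx.core as mx
import mlx.optimizers as optim


def linear_warmup(init: float, target: float, warmup_steps: int):
    """Ramp linearly from ``init`` to ``target`` over ``warmup_steps`` steps."""

    def schedule(step):
        frac = mx.minimum(step, warmup_steps).astype(mx.float32) / warmup_steps
        return init + (target - init) * frac

    return schedule


optimizer = optim.SGD(learning_rate=linear_warmup(1e-3, 1e-1, 100))
```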
| @@ -971,6 +971,12 @@ void init_array(py::module_& m) { | ||||
|             return power(a, to_array(v, a.dtype())); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|           "__rpow__", | ||||
|           [](const array& a, const ScalarOrArray v) { | ||||
|             return power(to_array(v, a.dtype()), a); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|           "__ipow__", | ||||
|           [](array& a, const ScalarOrArray v) { | ||||
|   | ||||
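The hunk above adds an `__rpow__` binding to the array Python bindings, which is what the schedulers lean on when a Python scalar is raised to an `mx.array` exponent (e.g. `decay_rate ** step`, with `step` stored as an array in the optimizer state). The remaining hunks below are the accompanying test changes, presumably in `python/tests/test_optimizers.py`. A quick sanity check of the new operator:

```python
import mlx.core as mx

step = mx.array(3, mx.uint64)
print(0.9 ** step)  # expected: array(0.729, dtype=float32)
```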
| @@ -1,6 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import inspect | ||||
| import math | ||||
| import unittest | ||||
| from functools import partial | ||||
|  | ||||
| @@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map | ||||
| def get_all_optimizers(): | ||||
|     classes = dict() | ||||
|     for name, obj in inspect.getmembers(opt): | ||||
|         if inspect.isclass(obj): | ||||
|             if obj.__name__ not in ["Optimizer"]: | ||||
|                 classes[name] = obj | ||||
|         if ( | ||||
|             inspect.isclass(obj) | ||||
|             and issubclass(obj, opt.Optimizer) | ||||
|             and obj != opt.Optimizer | ||||
|         ): | ||||
|             classes[name] = obj | ||||
|     return classes | ||||
|  | ||||
|  | ||||
| @@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|         x = mx.zeros((5, 5)) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             xp = optimizer.apply_gradients(grad, x) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
|             self.assertEqual(xp.shape, x.shape) | ||||
|  | ||||
|         x = mx.zeros((5, 5), mx.float16) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             xp = optimizer.apply_gradients(grad, x) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
|             self.assertEqual(xp.shape, x.shape) | ||||
|         self.assertEqual(optimizer.state["step"], 2) | ||||
| @@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) | ||||
|  | ||||
|  | ||||
| class TestSchedulers(unittest.TestCase): | ||||
|     def test_decay_lr(self): | ||||
|         for optim_class in optimizers_dict.values(): | ||||
|             lr_schedule = opt.step_decay(1e-1, 0.9, 1000) | ||||
|             optimizer = optim_class(learning_rate=lr_schedule) | ||||
|  | ||||
|             params = {"w": mx.ones((5, 5))} | ||||
|             grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|             for it in range(10): | ||||
|                 expected_lr = 0.1 * (0.9 ** (it // 1000)) | ||||
|                 self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7) | ||||
|                 optimizer.apply_gradients(grads, params) | ||||
|  | ||||
|     def test_step_decay(self): | ||||
|         lr_schedule = opt.step_decay(1e-1, 0.9, 1000) | ||||
|         lr = lr_schedule(2500) | ||||
|         expected_lr = 0.1 * (0.9**2) | ||||
|         self.assertAlmostEqual(lr, expected_lr, delta=1e-7) | ||||
|  | ||||
|     def test_exponential_decay(self): | ||||
|         lr_schedule = opt.exponential_decay(1e-1, 0.99) | ||||
|         lr = lr_schedule(10) | ||||
|         expected_lr = 0.1 * (0.99**10) | ||||
|         self.assertAlmostEqual(lr, expected_lr, delta=1e-7) | ||||
|  | ||||
|     def test_cosine_decay(self): | ||||
|         lr_schedule = opt.cosine_decay(0.1, 10) | ||||
|         lr = lr_schedule(4) | ||||
|         expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10)) | ||||
|         self.assertAlmostEqual(lr, expected_lr, delta=1e-7) | ||||
|  | ||||
|     def test_compile_with_schedule(self): | ||||
|         lr_schedule = opt.exponential_decay(1e-1, 0.9) | ||||
|         optimizer = opt.SGD(learning_rate=lr_schedule) | ||||
|  | ||||
|         @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state) | ||||
|         def update(): | ||||
|             optimizer.update({}, {}) | ||||
|  | ||||
|         for step in range(5): | ||||
|             update() | ||||
|             self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
Srimukh Sripada