From 818cda16bcc4a68a1e971874c42f909df30145e3 Mon Sep 17 00:00:00 2001 From: Srimukh Sripada Date: Thu, 15 Feb 2024 20:26:20 +0100 Subject: [PATCH] Support LR schedulers (#334) * Add a few LR schedulers * Move parents's constructor call to the top * Fix docstring * refactor optimizers into two files * add docs * nit * Fix Callable type annotation for python 3.8 --------- Co-authored-by: Awni Hannun Co-authored-by: Angelos Katharopoulos --- docs/.gitignore | 1 + docs/src/python/optimizers.rst | 20 +---- .../python/optimizers/common_optimizers.rst | 20 +++++ .../src/python/{ => optimizers}/optimizer.rst | 0 docs/src/python/optimizers/schedulers.rst | 13 +++ python/mlx/optimizers/__init__.py | 4 + python/mlx/{ => optimizers}/optimizers.py | 71 ++++++++++----- python/mlx/optimizers/schedulers.py | 86 +++++++++++++++++++ python/src/array.cpp | 6 ++ python/tests/test_optimizers.py | 61 +++++++++++-- 10 files changed, 235 insertions(+), 47 deletions(-) create mode 100644 docs/src/python/optimizers/common_optimizers.rst rename docs/src/python/{ => optimizers}/optimizer.rst (100%) create mode 100644 docs/src/python/optimizers/schedulers.rst create mode 100644 python/mlx/optimizers/__init__.py rename python/mlx/{ => optimizers}/optimizers.py (92%) create mode 100644 python/mlx/optimizers/schedulers.py diff --git a/docs/.gitignore b/docs/.gitignore index 5c2693cb6..fa80a135e 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ src/python/_autosummary*/ src/python/nn/_autosummary*/ +src/python/optimizers/_autosummary*/ diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index 4ef43d50f..f437ddc15 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -31,20 +31,6 @@ model's parameters and the **optimizer state**. .. toctree:: - optimizer - -.. currentmodule:: mlx.optimizers - -.. autosummary:: - :toctree: _autosummary - :template: optimizers-template.rst - - SGD - RMSprop - Adagrad - Adafactor - AdaDelta - Adam - AdamW - Adamax - Lion + optimizers/optimizer + optimizers/common_optimizers + optimizers/schedulers diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst new file mode 100644 index 000000000..41b3fba03 --- /dev/null +++ b/docs/src/python/optimizers/common_optimizers.rst @@ -0,0 +1,20 @@ +.. _common_optimizers: + +Common Optimizers +================= + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + :template: optimizers-template.rst + + SGD + RMSprop + Adagrad + Adafactor + AdaDelta + Adam + AdamW + Adamax + Lion diff --git a/docs/src/python/optimizer.rst b/docs/src/python/optimizers/optimizer.rst similarity index 100% rename from docs/src/python/optimizer.rst rename to docs/src/python/optimizers/optimizer.rst diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst new file mode 100644 index 000000000..a83883ddb --- /dev/null +++ b/docs/src/python/optimizers/schedulers.rst @@ -0,0 +1,13 @@ +.. _schedulers: + +Schedulers +========== + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + + step_decay + exponential_decay + cosine_decay diff --git a/python/mlx/optimizers/__init__.py b/python/mlx/optimizers/__init__.py new file mode 100644 index 000000000..6e8e0ccd4 --- /dev/null +++ b/python/mlx/optimizers/__init__.py @@ -0,0 +1,4 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +from mlx.optimizers.optimizers import * +from mlx.optimizers.schedulers import * diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers/optimizers.py similarity index 92% rename from python/mlx/optimizers.py rename to python/mlx/optimizers/optimizers.py index 4a53d4681..16928625f 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -1,7 +1,7 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import mlx.core as mx from mlx.utils import tree_map @@ -12,9 +12,10 @@ class Optimizer: optimizer on a per-parameter basis and apply it to a parameter tree. """ - def __init__(self): + def __init__(self, schedulers=None): self._initialized = False - self._state = {} + self._state = {"step": mx.array(0, mx.uint64)} + self._schedulers = {k: v for k, v in (schedulers or {}).items()} def update(self, model: "mlx.nn.Module", gradients: dict): """Apply the gradients to the parameters of the model and update the @@ -44,9 +45,8 @@ class Optimizer: >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) >>> model = nn.Linear(2, 2) >>> optimizer.init(model.trainable_parameters()) - >>> optimizer.state - {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], - [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} + >>> optimizer.state.keys() + dict_keys(['step', 'learning_rate', 'weight', 'bias']) """ self._state.update(tree_map(lambda x: {}, parameters)) tree_map(self.init_single, parameters, self._state) @@ -76,6 +76,15 @@ class Optimizer: """ if not self._initialized: self.init(gradients) + + # Update any scheduled variables + for param, scheduler in self._schedulers.items(): + self.state[param] = scheduler(self.step) + + # Increment the step + self.state["step"] = self.step + 1 + + # Apply the update return tree_map(self.apply_single, gradients, parameters, self.state) def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): @@ -97,14 +106,31 @@ class Optimizer: def state(self, state: dict): self._state = state + @property + def step(self): + return self.state["step"] + @property def learning_rate(self): return self.state["learning_rate"] @learning_rate.setter - def learning_rate(self, learning_rate: mx.array): + def learning_rate(self, learning_rate: Union[float, mx.array]): self.state["learning_rate"] = mx.array(learning_rate) + def _maybe_schedule( + self, name: str, param: Union[float, Callable[[mx.array], mx.array]] + ): + """ + To be used by derived classes to optionally put a parameter on a schedule. + """ + if isinstance(param, Callable): + self._schedulers[name] = param + param = param(self.step) + else: + param = mx.array(param) + self.state[name] = param + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. @@ -117,7 +143,7 @@ class SGD(Optimizer): w_{t+1} &= w_t - \lambda v_{t+1} Args: - learning_rate (float): The learning rate :math:`\lambda`. + learning_rate (float or callable): The learning rate :math:`\lambda`. momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0`` weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0`` dampening (float, optional): Dampening for momentum :math:`\tau`. 
Default: ``0`` @@ -126,7 +152,7 @@ class SGD(Optimizer): def __init__( self, - learning_rate: float, + learning_rate: Union[float, Callable[[mx.array], mx.array]], momentum: float = 0.0, weight_decay: float = 0.0, dampening: float = 0.0, @@ -138,7 +164,7 @@ class SGD(Optimizer): ) super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening @@ -194,7 +220,7 @@ class RMSprop(Optimizer): def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.alpha = alpha self.eps = eps @@ -246,7 +272,7 @@ class Adagrad(Optimizer): def __init__(self, learning_rate: float, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps if self.eps < 0.0: @@ -295,7 +321,7 @@ class AdaDelta(Optimizer): def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.rho = rho self.eps = eps if self.rho < 0.0: @@ -361,7 +387,7 @@ class Adam(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.eps = eps @@ -526,7 +552,7 @@ class Lion(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.weight_decay = weight_decay @@ -596,7 +622,7 @@ class Adafactor(Optimizer): ): super().__init__() if learning_rate is not None: - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps self.clip_threshold = clip_threshold self.decay_rate = decay_rate @@ -608,7 +634,6 @@ class Adafactor(Optimizer): def init_single(self, parameter: mx.array, state: dict): """Initialize optimizer state""" - state["step"] = 0 if parameter.ndim >= 2: shape = parameter.shape dtype = parameter.dtype @@ -626,10 +651,11 @@ class Adafactor(Optimizer): def _compute_learning_rate(self, step, parameter_rms): if self.relative_step: min_step = 1e-6 * step if self.warmup_init else 1e-2 - relative_step_size = min(min_step, 1 / math.sqrt(step)) + relative_step_size = mx.minimum(min_step, mx.rsqrt(step)) else: - relative_step_size = self.learning_rate.astype(parameter_rms) + relative_step_size = self.learning_rate + relative_step_size = relative_step_size.astype(parameter_rms.dtype) parameter_scale = 1.0 if self.scale_parameter: parameter_scale = mx.maximum(self.eps[1], parameter_rms) @@ -648,13 +674,12 @@ class Adafactor(Optimizer): """Performs the Adafactor parameter and state update.""" factored = gradient.ndim >= 2 - step = state["step"] + 1 - state["step"] = step + step = self.step use_first_moment = self.beta_1 is not None parameter_rms = self._compute_rms(parameter) learning_rate = self._compute_learning_rate(step, parameter_rms) - beta_2 = 1.0 - math.pow(step, self.decay_rate) + beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype) update = mx.square(gradient) + self.eps[0] if factored: diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py new file mode 100644 index 000000000..da058c03a --- /dev/null +++ b/python/mlx/optimizers/schedulers.py @@ -0,0 +1,86 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +import math + +import mlx.core as mx + + +def exponential_decay(init: float, decay_rate: float): + r"""Make an exponential decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + + Example: + >>> lr_schedule = optim.exponential_decay(1e-1, 0.9) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.06561, dtype=float32) + """ + + def schedule(step): + return init * decay_rate**step + + return schedule + + +def step_decay(init: float, decay_rate: float, step_size: int): + r"""Make a step decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + step_size (int): Decay every ``step_size`` steps. + + Example: + + >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(21): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.081, dtype=float32) + """ + + def schedule(step): + return init * (decay_rate ** (step // step_size)) + + return schedule + + +def cosine_decay(init: float, decay_steps: int): + r"""Make a cosine decay scheduler. + + Args: + init (float): Initial value. + decay_steps (int): Number of steps to decay over. The decayed + value is constant for steps beyond ``decay_steps``. + + Example: + + >>> lr_schedule = optim.cosine_decay(1e-1, 1000) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.0999961, dtype=float32) + """ + + def scheduler(step): + s = mx.minimum(step, decay_steps) + decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) + return init * decay + + return scheduler diff --git a/python/src/array.cpp b/python/src/array.cpp index 4395d50e6..57b867dbc 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -971,6 +971,12 @@ void init_array(py::module_& m) { return power(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__rpow__", + [](const array& a, const ScalarOrArray v) { + return power(to_array(v, a.dtype()), a); + }, + "other"_a) .def( "__ipow__", [](array& a, const ScalarOrArray v) { diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index f894a7510..f978943de 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. 
import inspect +import math import unittest from functools import partial @@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map def get_all_optimizers(): classes = dict() for name, obj in inspect.getmembers(opt): - if inspect.isclass(obj): - if obj.__name__ not in ["Optimizer"]: - classes[name] = obj + if ( + inspect.isclass(obj) + and issubclass(obj, opt.Optimizer) + and obj != opt.Optimizer + ): + classes[name] = obj return classes @@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase): x = mx.zeros((5, 5)) grad = mx.ones_like(x) optimizer = opt.Adafactor() - optimizer.init(x) for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) x = mx.zeros((5, 5), mx.float16) grad = mx.ones_like(x) optimizer = opt.Adafactor() - optimizer.init(x) for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) self.assertEqual(optimizer.state["step"], 2) @@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) +class TestSchedulers(unittest.TestCase): + def test_decay_lr(self): + for optim_class in optimizers_dict.values(): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + optimizer = optim_class(learning_rate=lr_schedule) + + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + + for it in range(10): + expected_lr = 0.1 * (0.9**it) + self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7) + return optimizer.apply_gradients(grads, params) + + def test_step_decay(self): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + lr = lr_schedule(2500) + expected_lr = 0.1 * (0.9**2) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_exponential_decay(self): + lr_schedule = opt.exponential_decay(1e-1, 0.99) + lr = lr_schedule(10) + expected_lr = 0.1 * (0.99**10) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_cosine_decay(self): + lr_schedule = opt.cosine_decay(0.1, 10) + lr = lr_schedule(4) + expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10)) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_compile_with_schedule(self): + lr_schedule = opt.exponential_decay(1e-1, 0.9) + optimizer = opt.SGD(learning_rate=lr_schedule) + + @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state) + def update(): + optimizer.update({}, {}) + + for step in range(5): + update() + self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item()) + + if __name__ == "__main__": unittest.main()
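
Below is a minimal usage sketch of the scheduler API introduced by this patch. The step_decay factory, the callable learning_rate argument, apply_gradients, and the _maybe_schedule behavior are taken from the diff above; the toy parameters, gradients, and the linear_warmup helper are illustrative assumptions, not part of the change.

import mlx.core as mx
import mlx.optimizers as optim
from mlx.utils import tree_map

# A built-in schedule: multiply the learning rate by 0.9 every 10 steps.
lr_schedule = optim.step_decay(1e-1, 0.9, 10)
optimizer = optim.SGD(learning_rate=lr_schedule, momentum=0.9)

# Toy parameters and gradients, assumed for illustration.
params = {"w": mx.ones((4, 4))}
grads = tree_map(lambda p: 0.01 * mx.ones_like(p), params)

for _ in range(25):
    params = optimizer.apply_gradients(grads, params)

# The last update ran at step 24, so the rate has decayed twice:
# 0.1 * 0.9**2, i.e. the same 0.081 shown in the step_decay docstring.
print(optimizer.learning_rate)  # array(0.081, dtype=float32)

# Because _maybe_schedule accepts any callable of the step counter, a custom
# schedule works too, e.g. this hypothetical linear warmup over 100 steps.
def linear_warmup(step):
    return 1e-3 * mx.minimum(step, 100) / 100

adam = optim.Adam(learning_rate=linear_warmup)

The same pattern applies to exponential_decay and cosine_decay, and, as the test_compile_with_schedule test above shows, a scheduled optimizer can be used inside mx.compile as long as optimizer.state is passed as compile inputs and outputs.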