Support LR schedulers (#334)

* Add a few LR schedulers

* Move parent's constructor call to the top

* Fix docstring

* Refactor optimizers into two files

* Add docs

* nit

* Fix Callable type annotation for Python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Authored by Srimukh Sripada on 2024-02-15 20:26:20 +01:00, committed by GitHub
parent 85143fecdd
commit 818cda16bc
10 changed files with 235 additions and 47 deletions

docs/.gitignore

@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/


@@ -31,20 +31,6 @@ model's parameters and the **optimizer state**.
 .. toctree::

-   optimizer
-
-.. currentmodule:: mlx.optimizers
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: optimizers-template.rst
-
-   SGD
-   RMSprop
-   Adagrad
-   Adafactor
-   AdaDelta
-   Adam
-   AdamW
-   Adamax
-   Lion
+   optimizers/optimizer
+   optimizers/common_optimizers
+   optimizers/schedulers


@@ -0,0 +1,20 @@
.. _common_optimizers:

Common Optimizers
=================

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion


@@ -0,0 +1,13 @@
.. _schedulers:

Schedulers
==========

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary

   step_decay
   exponential_decay
   cosine_decay


@@ -0,0 +1,4 @@
# Copyright © 2023-2024 Apple Inc.

from mlx.optimizers.optimizers import *
from mlx.optimizers.schedulers import *
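Because the package `__init__` re-exports both modules, optimizers and schedulers share the `mlx.optimizers` namespace, so a schedule can be passed directly as a `learning_rate`. A quick doctest-style sketch; the value shown follows from evaluating `cosine_decay` at step 0:

>>> import mlx.optimizers as optim
>>> optimizer = optim.SGD(learning_rate=optim.cosine_decay(1e-2, 1000))
>>> optimizer.learning_rate
array(0.01, dtype=float32)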


@@ -1,7 +1,7 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
 from mlx.utils import tree_map
@@ -12,9 +12,10 @@ class Optimizer:
     optimizer on a per-parameter basis and apply it to a parameter tree.
     """

-    def __init__(self):
+    def __init__(self, schedulers=None):
         self._initialized = False
-        self._state = {}
+        self._state = {"step": mx.array(0, mx.uint64)}
+        self._schedulers = {k: v for k, v in (schedulers or {}).items()}

     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
@@ -44,9 +45,8 @@ class Optimizer:
         >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
         >>> model = nn.Linear(2, 2)
         >>> optimizer.init(model.trainable_parameters())
-        >>> optimizer.state
-        {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
-        [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+        >>> optimizer.state.keys()
+        dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
         self._state.update(tree_map(lambda x: {}, parameters))
         tree_map(self.init_single, parameters, self._state)
@@ -76,6 +76,15 @@ class Optimizer:
         """
         if not self._initialized:
             self.init(gradients)
+
+        # Update any scheduled variables
+        for param, scheduler in self._schedulers.items():
+            self.state[param] = scheduler(self.step)
+
+        # Increment the step
+        self.state["step"] = self.step + 1
+
+        # Apply the update
         return tree_map(self.apply_single, gradients, parameters, self.state)

     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
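A scheduler here is just a callable that maps the step counter (an ``mx.array`` holding the update count) to a value; ``apply_gradients`` evaluates every registered scheduler with the current step and then increments it. As an illustration only, a linear warmup written against that contract could look like the sketch below (this helper is not part of the commit):

import mlx.core as mx

def linear_warmup(init: float, end: float, warmup_steps: int):
    # Hypothetical schedule: ramp from `init` to `end` over `warmup_steps`
    # updates, then hold `end`. Any callable taking the step array works.
    def schedule(step):
        frac = mx.minimum(step, warmup_steps) / warmup_steps
        return init + (end - init) * frac

    return schedule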
@@ -97,14 +106,31 @@ class Optimizer:
     def state(self, state: dict):
         self._state = state

+    @property
+    def step(self):
+        return self.state["step"]
+
     @property
     def learning_rate(self):
         return self.state["learning_rate"]

     @learning_rate.setter
-    def learning_rate(self, learning_rate: mx.array):
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
         self.state["learning_rate"] = mx.array(learning_rate)

+    def _maybe_schedule(
+        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
+    ):
+        """
+        To be used by derived classes to optionally put a parameter on a schedule.
+        """
+        if isinstance(param, Callable):
+            self._schedulers[name] = param
+            param = param(self.step)
+        else:
+            param = mx.array(param)
+        self.state[name] = param
+

 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
@@ -117,7 +143,7 @@ class SGD(Optimizer):
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
@@ -126,7 +152,7 @@ class SGD(Optimizer):
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         dampening: float = 0.0,
@@ -138,7 +164,7 @@ class SGD(Optimizer):
             )

         super().__init__()
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.dampening = dampening
@@ -194,7 +220,7 @@ class RMSprop(Optimizer):
     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.alpha = alpha
         self.eps = eps
@@ -246,7 +272,7 @@ class Adagrad(Optimizer):
     def __init__(self, learning_rate: float, eps: float = 1e-8):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps

         if self.eps < 0.0:
@@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.rho = rho
         self.eps = eps

         if self.rho < 0.0:
@@ -361,7 +387,7 @@ class Adam(Optimizer):
     ):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
@@ -526,7 +552,7 @@ class Lion(Optimizer):
     ):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay
@@ -596,7 +622,7 @@ class Adafactor(Optimizer):
     ):
         super().__init__()
         if learning_rate is not None:
-            self.learning_rate = learning_rate
+            self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
         self.clip_threshold = clip_threshold
         self.decay_rate = decay_rate
@@ -608,7 +634,6 @@ class Adafactor(Optimizer):
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["step"] = 0
         if parameter.ndim >= 2:
             shape = parameter.shape
             dtype = parameter.dtype
@@ -626,10 +651,11 @@ class Adafactor(Optimizer):
     def _compute_learning_rate(self, step, parameter_rms):
         if self.relative_step:
             min_step = 1e-6 * step if self.warmup_init else 1e-2
-            relative_step_size = min(min_step, 1 / math.sqrt(step))
+            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
         else:
-            relative_step_size = self.learning_rate.astype(parameter_rms)
+            relative_step_size = self.learning_rate
+        relative_step_size = relative_step_size.astype(parameter_rms.dtype)

         parameter_scale = 1.0
         if self.scale_parameter:
             parameter_scale = mx.maximum(self.eps[1], parameter_rms)
@@ -648,13 +674,12 @@ class Adafactor(Optimizer):
         """Performs the Adafactor parameter and state update."""
         factored = gradient.ndim >= 2

-        step = state["step"] + 1
-        state["step"] = step
+        step = self.step
         use_first_moment = self.beta_1 is not None

         parameter_rms = self._compute_rms(parameter)
         learning_rate = self._compute_learning_rate(step, parameter_rms)
-        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
         update = mx.square(gradient) + self.eps[0]

         if factored:
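Taken together, a derived optimizer only has to route its constructor argument through ``_maybe_schedule`` to accept either a float or a schedule; step tracking and per-step evaluation live in the base class. A minimal illustrative subclass, not part of this commit (the name ``PlainSGD`` is made up):

import mlx.core as mx
from mlx.optimizers import Optimizer


class PlainSGD(Optimizer):
    """Bare-bones SGD-style optimizer showing the scheduling hook."""

    def __init__(self, learning_rate):
        super().__init__()
        # Stores a float directly, or registers a callable schedule that is
        # re-evaluated on every call to `update`/`apply_gradients`.
        self._maybe_schedule("learning_rate", learning_rate)

    def init_single(self, parameter, state):
        pass  # no per-parameter state needed

    def apply_single(self, gradient, parameter, state):
        return parameter - self.learning_rate.astype(gradient.dtype) * gradient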


@@ -0,0 +1,86 @@
# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx


def exponential_decay(init: float, decay_rate: float):
    r"""Make an exponential decay scheduler.

    Args:
        init (float): Initial value.
        decay_rate (float): Multiplicative factor to decay by.

    Example:
        >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
        >>>
        >>> for _ in range(5): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.06561, dtype=float32)
    """

    def schedule(step):
        return init * decay_rate**step

    return schedule


def step_decay(init: float, decay_rate: float, step_size: int):
    r"""Make a step decay scheduler.

    Args:
        init (float): Initial value.
        decay_rate (float): Multiplicative factor to decay by.
        step_size (int): Decay every ``step_size`` steps.

    Example:
        >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
        >>>
        >>> for _ in range(21): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.081, dtype=float32)
    """

    def schedule(step):
        return init * (decay_rate ** (step // step_size))

    return schedule


def cosine_decay(init: float, decay_steps: int):
    r"""Make a cosine decay scheduler.

    Args:
        init (float): Initial value.
        decay_steps (int): Number of steps to decay over. The decayed
            value is constant for steps beyond ``decay_steps``.

    Example:
        >>> lr_schedule = optim.cosine_decay(1e-1, 1000)
        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
        >>>
        >>> for _ in range(5): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.0999961, dtype=float32)
    """

    def scheduler(step):
        s = mx.minimum(step, decay_steps)
        decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
        return init * decay

    return scheduler
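For context, this is roughly how a schedule plugs into a training loop. The sketch follows the usual MLX training pattern rather than anything introduced by this commit; the model, loss, and data are placeholders:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(learning_rate=optim.step_decay(1e-1, 0.5, 100))


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))

for _ in range(300):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # also advances the step and the schedule
    mx.eval(model.parameters(), optimizer.state)

# After 300 updates the last scheduled value was 0.1 * 0.5 ** (299 // 100) = 0.025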


@@ -971,6 +971,12 @@ void init_array(py::module_& m) {
             return power(a, to_array(v, a.dtype()));
           },
           "other"_a)
+      .def(
+          "__rpow__",
+          [](const array& a, const ScalarOrArray v) {
+            return power(to_array(v, a.dtype()), a);
+          },
+          "other"_a)
       .def(
           "__ipow__",
           [](array& a, const ScalarOrArray v) {
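The new ``__rpow__`` binding is presumably what lets the schedulers above write ``decay_rate**step`` with a Python float on the left and an ``mx.array`` (the optimizer's step counter) on the right. A quick doctest-style check, mirroring the scheduler docstrings:

>>> import mlx.core as mx
>>> import mlx.optimizers as optim
>>> schedule = optim.exponential_decay(1e-1, 0.9)
>>> schedule(mx.array(4, mx.uint64))  # float ** array dispatches to __rpow__
array(0.06561, dtype=float32)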


@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import inspect
+import math
 import unittest
 from functools import partial
@@ -15,8 +16,11 @@ from mlx.utils import tree_flatten, tree_map
 def get_all_optimizers():
     classes = dict()
     for name, obj in inspect.getmembers(opt):
-        if inspect.isclass(obj):
-            if obj.__name__ not in ["Optimizer"]:
-                classes[name] = obj
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, opt.Optimizer)
+            and obj != opt.Optimizer
+        ):
+            classes[name] = obj
     return classes
@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         x = mx.zeros((5, 5))
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)

         x = mx.zeros((5, 5), mx.float16)
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
@@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))

+class TestSchedulers(unittest.TestCase):
+    def test_decay_lr(self):
+        for optim_class in optimizers_dict.values():
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            optimizer = optim_class(learning_rate=lr_schedule)
+
+            params = {"w": mx.ones((5, 5))}
+            grads = tree_map(lambda x: mx.ones_like(x), params)
+
+            for it in range(10):
+                expected_lr = 0.1 * (0.9**it)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
+                return optimizer.apply_gradients(grads, params)
+
+    def test_step_decay(self):
+        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+        lr = lr_schedule(2500)
+        expected_lr = 0.1 * (0.9**2)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_exponential_decay(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.99)
+        lr = lr_schedule(10)
+        expected_lr = 0.1 * (0.99**10)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_cosine_decay(self):
+        lr_schedule = opt.cosine_decay(0.1, 10)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_compile_with_schedule(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.9)
+        optimizer = opt.SGD(learning_rate=lr_schedule)
+
+        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
+        def update():
+            optimizer.update({}, {})
+
+        for step in range(5):
+            update()
+            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
+
+
 if __name__ == "__main__":
     unittest.main()