Support LR schedulers (#334)

* Add a few LR schedulers

* Move parent's constructor call to the top

* Fix docstring

* Refactor optimizers into two files

* Add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Srimukh Sripada 2024-02-15 20:26:20 +01:00 committed by GitHub
parent 85143fecdd
commit 818cda16bc
10 changed files with 235 additions and 47 deletions

docs/.gitignore
View File

@@ -1,2 +1,3 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

View File

@@ -31,20 +31,6 @@ model's parameters and the **optimizer state**.
.. toctree::
optimizer
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers

View File

@@ -0,0 +1,20 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -0,0 +1,13 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
step_decay
exponential_decay
cosine_decay
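
Each entry above is a factory that returns a plain callable mapping a step count to a value, so a schedule can be evaluated directly or handed to any optimizer as its ``learning_rate``. A minimal usage sketch (the commented values are approximate and only illustrative):

import mlx.optimizers as optim

# The three schedulers documented above; each maps a step count to a value.
exp_lr = optim.exponential_decay(1e-1, 0.9)   # 0.1 * 0.9**step
step_lr = optim.step_decay(1e-1, 0.9, 10)     # multiply by 0.9 every 10 steps
cos_lr = optim.cosine_decay(1e-1, 1000)       # cosine curve over 1000 steps

print(exp_lr(5))     # ~0.059
print(step_lr(25))   # 0.1 * 0.9**2 = 0.081
print(cos_lr(500))   # ~0.05, halfway through the decay

# Passing a schedule as the learning rate puts it under the optimizer's control.
optimizer = optim.SGD(learning_rate=cos_lr)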

View File

@@ -0,0 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
from mlx.optimizers.optimizers import *
from mlx.optimizers.schedulers import *
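
Because both submodules are star-imported here, the refactor keeps the public namespace flat: existing ``import mlx.optimizers`` code keeps working, and the new schedulers sit alongside the optimizer classes. A quick sketch of what is reachable from the package (a sanity check, not part of the change):

import mlx.optimizers as optim

# Optimizer classes come from mlx.optimizers.optimizers ...
print(optim.SGD, optim.Adam, optim.Adafactor)
# ... and the schedulers from mlx.optimizers.schedulers.
print(optim.step_decay, optim.exponential_decay, optim.cosine_decay)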

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_map
@@ -12,9 +12,10 @@ class Optimizer:
optimizer on a per-parameter basis and apply it to a parameter tree.
"""
def __init__(self):
def __init__(self, schedulers=None):
self._initialized = False
self._state = {}
self._state = {"step": mx.array(0, mx.uint64)}
self._schedulers = {k: v for k, v in (schedulers or {}).items()}
def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the
@@ -44,9 +45,8 @@ class Optimizer:
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
>>> model = nn.Linear(2, 2)
>>> optimizer.init(model.trainable_parameters())
>>> optimizer.state
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
>>> optimizer.state.keys()
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)
@@ -76,6 +76,15 @@
"""
if not self._initialized:
self.init(gradients)
# Update any scheduled variables
for param, scheduler in self._schedulers.items():
self.state[param] = scheduler(self.step)
# Increment the step
self.state["step"] = self.step + 1
# Apply the update
return tree_map(self.apply_single, gradients, parameters, self.state)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
@@ -97,14 +106,31 @@
def state(self, state: dict):
self._state = state
@property
def step(self):
return self.state["step"]
@property
def learning_rate(self):
return self.state["learning_rate"]
@learning_rate.setter
def learning_rate(self, learning_rate: mx.array):
def learning_rate(self, learning_rate: Union[float, mx.array]):
self.state["learning_rate"] = mx.array(learning_rate)
def _maybe_schedule(
self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
):
"""
To be used by derived classes to optionally put a parameter on a schedule.
"""
if isinstance(param, Callable):
self._schedulers[name] = param
param = param(self.step)
else:
param = mx.array(param)
self.state[name] = param
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.
@@ -117,7 +143,7 @@ class SGD(Optimizer):
w_{t+1} &= w_t - \lambda v_{t+1}
Args:
learning_rate (float): The learning rate :math:`\lambda`.
learning_rate (float or callable): The learning rate :math:`\lambda`.
momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
@@ -126,7 +152,7 @@ class SGD(Optimizer):
def __init__(
self,
learning_rate: float,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
momentum: float = 0.0,
weight_decay: float = 0.0,
dampening: float = 0.0,
@@ -138,7 +164,7 @@
)
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.dampening = dampening
@@ -194,7 +220,7 @@ class RMSprop(Optimizer):
def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.alpha = alpha
self.eps = eps
@@ -246,7 +272,7 @@ class Adagrad(Optimizer):
def __init__(self, learning_rate: float, eps: float = 1e-8):
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.eps = eps
if self.eps < 0.0:
@@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.rho = rho
self.eps = eps
if self.rho < 0.0:
@@ -361,7 +387,7 @@ class Adam(Optimizer):
):
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.betas = betas
self.eps = eps
@@ -526,7 +552,7 @@ class Lion(Optimizer):
):
super().__init__()
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.betas = betas
self.weight_decay = weight_decay
@@ -596,7 +622,7 @@ class Adafactor(Optimizer):
):
super().__init__()
if learning_rate is not None:
self.learning_rate = learning_rate
self._maybe_schedule("learning_rate", learning_rate)
self.eps = eps
self.clip_threshold = clip_threshold
self.decay_rate = decay_rate
@@ -608,7 +634,6 @@ class Adafactor(Optimizer):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["step"] = 0
if parameter.ndim >= 2:
shape = parameter.shape
dtype = parameter.dtype
@@ -626,10 +651,11 @@ class Adafactor(Optimizer):
def _compute_learning_rate(self, step, parameter_rms):
if self.relative_step:
min_step = 1e-6 * step if self.warmup_init else 1e-2
relative_step_size = min(min_step, 1 / math.sqrt(step))
relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
else:
relative_step_size = self.learning_rate.astype(parameter_rms)
relative_step_size = self.learning_rate
relative_step_size = relative_step_size.astype(parameter_rms.dtype)
parameter_scale = 1.0
if self.scale_parameter:
parameter_scale = mx.maximum(self.eps[1], parameter_rms)
@@ -648,13 +674,12 @@ class Adafactor(Optimizer):
"""Performs the Adafactor parameter and state update."""
factored = gradient.ndim >= 2
step = state["step"] + 1
state["step"] = step
step = self.step
use_first_moment = self.beta_1 is not None
parameter_rms = self._compute_rms(parameter)
learning_rate = self._compute_learning_rate(step, parameter_rms)
beta_2 = 1.0 - math.pow(step, self.decay_rate)
beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
update = mx.square(gradient) + self.eps[0]
if factored:
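
To summarize the flow added to ``Optimizer`` above: every call to ``update``/``apply_gradients`` first evaluates each registered scheduler at the current ``step`` and writes the result into the state (e.g. ``learning_rate``), then increments ``step``, and only then applies the per-parameter update. A small sketch of that behaviour in the same style as the tests (dummy parameters; the commented values are illustrative):

import mlx.core as mx
import mlx.optimizers as optim
from mlx.utils import tree_map

# SGD with a scheduled learning rate of 0.1 * 0.9**step.
optimizer = optim.SGD(learning_rate=optim.exponential_decay(1e-1, 0.9))

params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)

for _ in range(3):
    # Evaluates the schedule at `step`, bumps `step`, then applies the update.
    params = optimizer.apply_gradients(grads, params)
    print(optimizer.state["step"], optimizer.learning_rate)
# step: 1, 2, 3   learning_rate: 0.1, 0.09, 0.081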

View File

@@ -0,0 +1,86 @@
# Copyright © 2023-2024 Apple Inc.
import math
import mlx.core as mx
def exponential_decay(init: float, decay_rate: float):
r"""Make an exponential decay scheduler.
Args:
init (float): Initial value.
decay_rate (float): Multiplicative factor to decay by.
Example:
>>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.1, dtype=float32)
>>>
>>> for _ in range(5): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.06561, dtype=float32)
"""
def schedule(step):
return init * decay_rate**step
return schedule
def step_decay(init: float, decay_rate: float, step_size: int):
r"""Make a step decay scheduler.
Args:
init (float): Initial value.
decay_rate (float): Multiplicative factor to decay by.
step_size (int): Decay every ``step_size`` steps.
Example:
>>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.1, dtype=float32)
>>>
>>> for _ in range(21): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.081, dtype=float32)
"""
def schedule(step):
return init * (decay_rate ** (step // step_size))
return schedule
def cosine_decay(init: float, decay_steps: int):
r"""Make a cosine decay scheduler.
Args:
init (float): Initial value.
decay_steps (int): Number of steps to decay over. The decayed
value is constant for steps beyond ``decay_steps``.
Example:
>>> lr_schedule = optim.cosine_decay(1e-1, 1000)
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.1, dtype=float32)
>>>
>>> for _ in range(5): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999961, dtype=float32)
"""
def scheduler(step):
s = mx.minimum(step, decay_steps)
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
return init * decay
return scheduler

View File

@@ -971,6 +971,12 @@ void init_array(py::module_& m) {
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rpow__",
[](const array& a, const ScalarOrArray v) {
return power(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) {
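
The ``__rpow__`` binding added above is what the schedulers rely on: the step counter lives in the optimizer state as an ``mx.array``, so an expression like ``decay_rate**step`` raises a Python scalar to an array power. A short sketch of the behaviour this enables (printed values are approximate):

import mlx.core as mx

step = mx.array(3, mx.uint64)  # same dtype as the optimizer's step counter
print(0.9 ** step)             # scalar ** array now dispatches to __rpow__, ~0.729
print(0.1 * 0.9 ** step)       # exponential_decay(1e-1, 0.9) at step 3, ~0.0729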

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import inspect
import math
import unittest
from functools import partial
@@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
def get_all_optimizers():
classes = dict()
for name, obj in inspect.getmembers(opt):
if inspect.isclass(obj):
if obj.__name__ not in ["Optimizer"]:
classes[name] = obj
if (
inspect.isclass(obj)
and issubclass(obj, opt.Optimizer)
and obj != opt.Optimizer
):
classes[name] = obj
return classes
@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
x = mx.zeros((5, 5))
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state)
xp = optimizer.apply_gradients(grad, x)
self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape)
x = mx.zeros((5, 5), mx.float16)
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state)
xp = optimizer.apply_gradients(grad, x)
self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
@@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
class TestSchedulers(unittest.TestCase):
def test_decay_lr(self):
for optim_class in optimizers_dict.values():
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
optimizer = optim_class(learning_rate=lr_schedule)
params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)
for it in range(10):
expected_lr = 0.1 * (0.9**it)
self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
return optimizer.apply_gradients(grads, params)
def test_step_decay(self):
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
lr = lr_schedule(2500)
expected_lr = 0.1 * (0.9**2)
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_exponential_decay(self):
lr_schedule = opt.exponential_decay(1e-1, 0.99)
lr = lr_schedule(10)
expected_lr = 0.1 * (0.99**10)
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_cosine_decay(self):
lr_schedule = opt.cosine_decay(0.1, 10)
lr = lr_schedule(4)
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_compile_with_schedule(self):
lr_schedule = opt.exponential_decay(1e-1, 0.9)
optimizer = opt.SGD(learning_rate=lr_schedule)
@partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
def update():
optimizer.update({}, {})
for step in range(5):
update()
self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
if __name__ == "__main__":
unittest.main()
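
For context, a minimal end-to-end sketch of what the change enables: training a model with an optimizer whose learning rate follows one of the new schedules. The model, data, and hyperparameters below are illustrative placeholders rather than anything taken from the diff:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(learning_rate=optim.cosine_decay(1e-1, 100))

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad = nn.value_and_grad(model, loss_fn)

# Toy data, just enough to drive a few update steps.
x = mx.random.normal((32, 4))
y = mx.random.normal((32, 1))

for step in range(10):
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)  # also advances the schedule by one step
    mx.eval(model.parameters(), optimizer.state)
    print(step, loss.item(), optimizer.learning_rate.item())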