mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Support LR schedulers (#334)
* Add a few LR schedulers * Move parents's constructor call to the top * Fix docstring * refactor optimizers into two files * add docs * nit * Fix Callable type annotation for python 3.8 --------- Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
85143fecdd
commit
818cda16bc
1
docs/.gitignore
vendored
1
docs/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
src/python/_autosummary*/
|
||||
src/python/nn/_autosummary*/
|
||||
src/python/optimizers/_autosummary*/
|
||||
|
@ -31,20 +31,6 @@ model's parameters and the **optimizer state**.
|
||||
|
||||
.. toctree::
|
||||
|
||||
optimizer
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: optimizers-template.rst
|
||||
|
||||
SGD
|
||||
RMSprop
|
||||
Adagrad
|
||||
Adafactor
|
||||
AdaDelta
|
||||
Adam
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
optimizers/optimizer
|
||||
optimizers/common_optimizers
|
||||
optimizers/schedulers
|
||||
|
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
@ -0,0 +1,20 @@
|
||||
.. _common_optimizers:
|
||||
|
||||
Common Optimizers
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: optimizers-template.rst
|
||||
|
||||
SGD
|
||||
RMSprop
|
||||
Adagrad
|
||||
Adafactor
|
||||
AdaDelta
|
||||
Adam
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
13
docs/src/python/optimizers/schedulers.rst
Normal file
13
docs/src/python/optimizers/schedulers.rst
Normal file
@ -0,0 +1,13 @@
|
||||
.. _schedulers:
|
||||
|
||||
Schedulers
|
||||
==========
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
step_decay
|
||||
exponential_decay
|
||||
cosine_decay
|
4
python/mlx/optimizers/__init__.py
Normal file
4
python/mlx/optimizers/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from mlx.optimizers.optimizers import *
|
||||
from mlx.optimizers.schedulers import *
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
@ -12,9 +12,10 @@ class Optimizer:
|
||||
optimizer on a per-parameter basis and apply it to a parameter tree.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, schedulers=None):
|
||||
self._initialized = False
|
||||
self._state = {}
|
||||
self._state = {"step": mx.array(0, mx.uint64)}
|
||||
self._schedulers = {k: v for k, v in (schedulers or {}).items()}
|
||||
|
||||
def update(self, model: "mlx.nn.Module", gradients: dict):
|
||||
"""Apply the gradients to the parameters of the model and update the
|
||||
@ -44,9 +45,8 @@ class Optimizer:
|
||||
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
|
||||
>>> model = nn.Linear(2, 2)
|
||||
>>> optimizer.init(model.trainable_parameters())
|
||||
>>> optimizer.state
|
||||
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
|
||||
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
|
||||
>>> optimizer.state.keys()
|
||||
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
||||
"""
|
||||
self._state.update(tree_map(lambda x: {}, parameters))
|
||||
tree_map(self.init_single, parameters, self._state)
|
||||
@ -76,6 +76,15 @@ class Optimizer:
|
||||
"""
|
||||
if not self._initialized:
|
||||
self.init(gradients)
|
||||
|
||||
# Update any scheduled variables
|
||||
for param, scheduler in self._schedulers.items():
|
||||
self.state[param] = scheduler(self.step)
|
||||
|
||||
# Increment the step
|
||||
self.state["step"] = self.step + 1
|
||||
|
||||
# Apply the update
|
||||
return tree_map(self.apply_single, gradients, parameters, self.state)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
@ -97,14 +106,31 @@ class Optimizer:
|
||||
def state(self, state: dict):
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def step(self):
|
||||
return self.state["step"]
|
||||
|
||||
@property
|
||||
def learning_rate(self):
|
||||
return self.state["learning_rate"]
|
||||
|
||||
@learning_rate.setter
|
||||
def learning_rate(self, learning_rate: mx.array):
|
||||
def learning_rate(self, learning_rate: Union[float, mx.array]):
|
||||
self.state["learning_rate"] = mx.array(learning_rate)
|
||||
|
||||
def _maybe_schedule(
|
||||
self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
|
||||
):
|
||||
"""
|
||||
To be used by derived classes to optionally put a parameter on a schedule.
|
||||
"""
|
||||
if isinstance(param, Callable):
|
||||
self._schedulers[name] = param
|
||||
param = param(self.step)
|
||||
else:
|
||||
param = mx.array(param)
|
||||
self.state[name] = param
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
r"""The stochastic gradient descent optimizer.
|
||||
@ -117,7 +143,7 @@ class SGD(Optimizer):
|
||||
w_{t+1} &= w_t - \lambda v_{t+1}
|
||||
|
||||
Args:
|
||||
learning_rate (float): The learning rate :math:`\lambda`.
|
||||
learning_rate (float or callable): The learning rate :math:`\lambda`.
|
||||
momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
|
||||
weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
|
||||
dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
|
||||
@ -126,7 +152,7 @@ class SGD(Optimizer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: float,
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
momentum: float = 0.0,
|
||||
weight_decay: float = 0.0,
|
||||
dampening: float = 0.0,
|
||||
@ -138,7 +164,7 @@ class SGD(Optimizer):
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.dampening = dampening
|
||||
@ -194,7 +220,7 @@ class RMSprop(Optimizer):
|
||||
def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.alpha = alpha
|
||||
self.eps = eps
|
||||
|
||||
@ -246,7 +272,7 @@ class Adagrad(Optimizer):
|
||||
def __init__(self, learning_rate: float, eps: float = 1e-8):
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.eps = eps
|
||||
|
||||
if self.eps < 0.0:
|
||||
@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
|
||||
def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.rho = rho
|
||||
self.eps = eps
|
||||
if self.rho < 0.0:
|
||||
@ -361,7 +387,7 @@ class Adam(Optimizer):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.betas = betas
|
||||
self.eps = eps
|
||||
|
||||
@ -526,7 +552,7 @@ class Lion(Optimizer):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.betas = betas
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
@ -596,7 +622,7 @@ class Adafactor(Optimizer):
|
||||
):
|
||||
super().__init__()
|
||||
if learning_rate is not None:
|
||||
self.learning_rate = learning_rate
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.eps = eps
|
||||
self.clip_threshold = clip_threshold
|
||||
self.decay_rate = decay_rate
|
||||
@ -608,7 +634,6 @@ class Adafactor(Optimizer):
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["step"] = 0
|
||||
if parameter.ndim >= 2:
|
||||
shape = parameter.shape
|
||||
dtype = parameter.dtype
|
||||
@ -626,10 +651,11 @@ class Adafactor(Optimizer):
|
||||
def _compute_learning_rate(self, step, parameter_rms):
|
||||
if self.relative_step:
|
||||
min_step = 1e-6 * step if self.warmup_init else 1e-2
|
||||
relative_step_size = min(min_step, 1 / math.sqrt(step))
|
||||
relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
|
||||
else:
|
||||
relative_step_size = self.learning_rate.astype(parameter_rms)
|
||||
relative_step_size = self.learning_rate
|
||||
|
||||
relative_step_size = relative_step_size.astype(parameter_rms.dtype)
|
||||
parameter_scale = 1.0
|
||||
if self.scale_parameter:
|
||||
parameter_scale = mx.maximum(self.eps[1], parameter_rms)
|
||||
@ -648,13 +674,12 @@ class Adafactor(Optimizer):
|
||||
"""Performs the Adafactor parameter and state update."""
|
||||
factored = gradient.ndim >= 2
|
||||
|
||||
step = state["step"] + 1
|
||||
state["step"] = step
|
||||
step = self.step
|
||||
use_first_moment = self.beta_1 is not None
|
||||
|
||||
parameter_rms = self._compute_rms(parameter)
|
||||
learning_rate = self._compute_learning_rate(step, parameter_rms)
|
||||
beta_2 = 1.0 - math.pow(step, self.decay_rate)
|
||||
beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
|
||||
update = mx.square(gradient) + self.eps[0]
|
||||
|
||||
if factored:
|
86
python/mlx/optimizers/schedulers.py
Normal file
86
python/mlx/optimizers/schedulers.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def exponential_decay(init: float, decay_rate: float):
|
||||
r"""Make an exponential decay scheduler.
|
||||
|
||||
Args:
|
||||
init (float): Initial value.
|
||||
decay_rate (float): Multiplicative factor to decay by.
|
||||
|
||||
Example:
|
||||
>>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
|
||||
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
|
||||
>>> optimizer.learning_rate
|
||||
array(0.1, dtype=float32)
|
||||
>>>
|
||||
>>> for _ in range(5): optimizer.update({}, {})
|
||||
...
|
||||
>>> optimizer.learning_rate
|
||||
array(0.06561, dtype=float32)
|
||||
"""
|
||||
|
||||
def schedule(step):
|
||||
return init * decay_rate**step
|
||||
|
||||
return schedule
|
||||
|
||||
|
||||
def step_decay(init: float, decay_rate: float, step_size: int):
|
||||
r"""Make a step decay scheduler.
|
||||
|
||||
Args:
|
||||
init (float): Initial value.
|
||||
decay_rate (float): Multiplicative factor to decay by.
|
||||
step_size (int): Decay every ``step_size`` steps.
|
||||
|
||||
Example:
|
||||
|
||||
>>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
|
||||
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
|
||||
>>> optimizer.learning_rate
|
||||
array(0.1, dtype=float32)
|
||||
>>>
|
||||
>>> for _ in range(21): optimizer.update({}, {})
|
||||
...
|
||||
>>> optimizer.learning_rate
|
||||
array(0.081, dtype=float32)
|
||||
"""
|
||||
|
||||
def schedule(step):
|
||||
return init * (decay_rate ** (step // step_size))
|
||||
|
||||
return schedule
|
||||
|
||||
|
||||
def cosine_decay(init: float, decay_steps: int):
|
||||
r"""Make a cosine decay scheduler.
|
||||
|
||||
Args:
|
||||
init (float): Initial value.
|
||||
decay_steps (int): Number of steps to decay over. The decayed
|
||||
value is constant for steps beyond ``decay_steps``.
|
||||
|
||||
Example:
|
||||
|
||||
>>> lr_schedule = optim.cosine_decay(1e-1, 1000)
|
||||
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
|
||||
>>> optimizer.learning_rate
|
||||
array(0.1, dtype=float32)
|
||||
>>>
|
||||
>>> for _ in range(5): optimizer.update({}, {})
|
||||
...
|
||||
>>> optimizer.learning_rate
|
||||
array(0.0999961, dtype=float32)
|
||||
"""
|
||||
|
||||
def scheduler(step):
|
||||
s = mx.minimum(step, decay_steps)
|
||||
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
||||
return init * decay
|
||||
|
||||
return scheduler
|
@ -971,6 +971,12 @@ void init_array(py::module_& m) {
|
||||
return power(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rpow__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return power(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ipow__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
|
||||
def get_all_optimizers():
|
||||
classes = dict()
|
||||
for name, obj in inspect.getmembers(opt):
|
||||
if inspect.isclass(obj):
|
||||
if obj.__name__ not in ["Optimizer"]:
|
||||
classes[name] = obj
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, opt.Optimizer)
|
||||
and obj != opt.Optimizer
|
||||
):
|
||||
classes[name] = obj
|
||||
return classes
|
||||
|
||||
|
||||
@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
x = mx.zeros((5, 5))
|
||||
grad = mx.ones_like(x)
|
||||
optimizer = opt.Adafactor()
|
||||
optimizer.init(x)
|
||||
for _ in range(2):
|
||||
xp = optimizer.apply_single(grad, x, optimizer.state)
|
||||
xp = optimizer.apply_gradients(grad, x)
|
||||
self.assertEqual(xp.dtype, x.dtype)
|
||||
self.assertEqual(xp.shape, x.shape)
|
||||
|
||||
x = mx.zeros((5, 5), mx.float16)
|
||||
grad = mx.ones_like(x)
|
||||
optimizer = opt.Adafactor()
|
||||
optimizer.init(x)
|
||||
for _ in range(2):
|
||||
xp = optimizer.apply_single(grad, x, optimizer.state)
|
||||
xp = optimizer.apply_gradients(grad, x)
|
||||
self.assertEqual(xp.dtype, x.dtype)
|
||||
self.assertEqual(xp.shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
||||
|
||||
|
||||
class TestSchedulers(unittest.TestCase):
|
||||
def test_decay_lr(self):
|
||||
for optim_class in optimizers_dict.values():
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
|
||||
optimizer = optim_class(learning_rate=lr_schedule)
|
||||
|
||||
params = {"w": mx.ones((5, 5))}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
for it in range(10):
|
||||
expected_lr = 0.1 * (0.9**it)
|
||||
self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
|
||||
return optimizer.apply_gradients(grads, params)
|
||||
|
||||
def test_step_decay(self):
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
|
||||
lr = lr_schedule(2500)
|
||||
expected_lr = 0.1 * (0.9**2)
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
def test_exponential_decay(self):
|
||||
lr_schedule = opt.exponential_decay(1e-1, 0.99)
|
||||
lr = lr_schedule(10)
|
||||
expected_lr = 0.1 * (0.99**10)
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
def test_cosine_decay(self):
|
||||
lr_schedule = opt.cosine_decay(0.1, 10)
|
||||
lr = lr_schedule(4)
|
||||
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
def test_compile_with_schedule(self):
|
||||
lr_schedule = opt.exponential_decay(1e-1, 0.9)
|
||||
optimizer = opt.SGD(learning_rate=lr_schedule)
|
||||
|
||||
@partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
|
||||
def update():
|
||||
optimizer.update({}, {})
|
||||
|
||||
for step in range(5):
|
||||
update()
|
||||
self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user