Added Adafactor (#415)

* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscores for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because it's not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Hazem Essam 2024-01-24 01:11:27 +02:00 committed by GitHub
parent 755dcf6137
commit 37fc9db82c
3 changed files with 164 additions and 10 deletions


@@ -40,6 +40,7 @@ model's parameters and the **optimizer state**.
SGD
RMSprop
Adagrad
+Adafactor
AdaDelta
Adam
AdamW


@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.

import math
-from typing import List
+from typing import List, Optional, Tuple

import mlx.core as mx
from mlx.utils import tree_map

@@ -76,7 +76,7 @@ class Optimizer:

class SGD(Optimizer):
-    r"""Stochastic gradient descent optimizer.
+    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

@@ -141,7 +141,7 @@ class SGD(Optimizer):

class RMSprop(Optimizer):
-    r"""Implementation of the RMSprop optimizer [1].
+    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

@@ -190,7 +190,7 @@ class RMSprop(Optimizer):

class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

@@ -235,7 +235,7 @@ class Adagrad(Optimizer):

class AdaDelta(Optimizer):
-    r"""Implementation of the AdaDelta optimizer with learning rate[1].
+    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

@@ -294,7 +294,7 @@ class AdaDelta(Optimizer):

class Adam(Optimizer):
-    r"""Implementation of the Adam optimizer [1].
+    r"""The Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

@@ -346,7 +346,7 @@ class Adam(Optimizer):

class AdamW(Adam):
-    r"""Implementation of the AdamW optimizer [1].
+    r"""The AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights

@@ -395,8 +395,7 @@ class AdamW(Adam):

class Adamax(Adam):
-    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
-    on the infinity norm [1].
+    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

@@ -449,7 +448,7 @@ class Adamax(Adam):

class Lion(Optimizer):
-    r"""Implementation of the Lion optimizer [1].
+    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.

@@ -502,3 +501,139 @@ class Lion(Optimizer):
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


class Adafactor(Optimizer):
    r"""The Adafactor optimizer.

    Our Adafactor implementation follows the original paper: `Adafactor:
    Adaptive Learning Rates with Sublinear Memory Cost
    <https://arxiv.org/abs/1804.04235>`_

    Args:
        learning_rate (float, optional): The learning rate. Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
            parameter scaling if ``parameter_scale`` is set to ``True``.
            Default: ``(1e-30, 1e-3)``.
        clip_threshold (float, optional): Clips the unscaled update at
            ``clip_threshold``. Default: ``1.0``.
        decay_rate (float, optional): Coefficient for the running average
            of the squared gradient. Default: ``-0.8``.
        beta_1 (float, optional): If set to a value bigger than zero
            then first moment will be used. Default: ``None``.
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.0``.
        scale_parameter (bool, optional): If set to ``True`` the learning rate
            will be scaled by :math:`\max(\epsilon_1, \text{RMS}(w_{t-1}))`.
            Default: ``True``.
        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
            will be ignored and relative step size will be computed.
            Default: ``True``.
        warmup_init (bool, optional): If set to ``True`` then the relative
            step size will be calculated by the current step. Default:
            ``False``.
    """

    def __init__(
        self,
        learning_rate: Optional[float] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta_1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta_1 = beta_1
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init

    def _compute_rms(self, inputs):
        return mx.sqrt(mx.mean(mx.square(inputs)))

    def _compute_learning_rate(self, step, parameter_rms):
        relative_step_size = self.learning_rate
        if self.relative_step:
            min_step = 1e-6 * step if self.warmup_init else 1e-2
            relative_step_size = min(min_step, 1 / math.sqrt(step))

        parameter_scale = 1.0
        if self.scale_parameter:
            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
        return parameter_scale * relative_step_size

    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = mx.rsqrt(
            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
        )
        c_factor = mx.rsqrt(exp_avg_sq_col)
        return mx.matmul(
            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
        )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adafactor parameter and state update."""
        gradient_shape = gradient.shape
        factored = len(gradient_shape) >= 2

        step = state.get("step", 0) + 1
        state["step"] = step
        use_first_moment = self.beta_1 is not None

        parameter_rms = self._compute_rms(parameter)
        learning_rate = self._compute_learning_rate(step, parameter_rms)
        beta_2 = 1.0 - math.pow(step, self.decay_rate)
        update = mx.square(gradient) + self.eps[0]

        if factored:
            exp_avg_sq_row = state.get(
                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
            )
            exp_avg_sq_col = state.get(
                "exp_avg_sq_col",
                mx.zeros(
                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
                ),
            )
            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                (1 - beta_2) * mx.mean(update, axis=-1)
            )
            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
                (1 - beta_2) * mx.mean(update, axis=-2)
            )
            state["exp_avg_sq_row"] = exp_avg_sq_row
            state["exp_avg_sq_col"] = exp_avg_sq_col
            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
            update = update * gradient
        else:
            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
            state["exp_avg_sq"] = exp_avg_sq
            update = mx.rsqrt(exp_avg_sq) * gradient

        update = update / mx.maximum(
            1.0, self._compute_rms(update) / self.clip_threshold
        )
        update = learning_rate * update

        if use_first_moment:
            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
            state["exp_avg"] = exp_avg
            update = exp_avg

        if self.weight_decay != 0:
            parameter += parameter * (-self.weight_decay * learning_rate)
        return parameter - update
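
For reference, a minimal usage sketch of the new optimizer (not part of the diff). It follows the calling pattern of the test added below, so the only assumptions are that the package is imported as ``mlx.optimizers`` and that ``apply_single`` is called directly with ``optimizer.state``:

    import mlx.core as mx
    import mlx.optimizers as opt

    # 2-D parameter: Adafactor takes the factored path and stores only row and
    # column averages of the squared gradient instead of a full (5, 5) matrix.
    w = mx.zeros((5, 5))
    g = mx.ones_like(w)
    optimizer = opt.Adafactor()  # relative_step=True by default, so no learning_rate is needed
    w = optimizer.apply_single(g, w, optimizer.state)
    print(optimizer.state["step"])                  # 1
    print(optimizer.state["exp_avg_sq_row"].shape)  # per-row statistics, 5 entries
    print(optimizer.state["exp_avg_sq_col"].shape)  # per-column statistics, 5 entries

    # 1-D parameter (e.g. a bias): the non-factored path keeps a full
    # "exp_avg_sq" accumulator, and beta_1 enables the optional first moment.
    b = mx.zeros((5,))
    optimizer_b = opt.Adafactor(beta_1=0.9)
    b = optimizer_b.apply_single(mx.ones_like(b), b, optimizer_b.state)
    print("exp_avg_sq" in optimizer_b.state, "exp_avg" in optimizer_b.state)  # True True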


@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
        self.assertTrue(all_equal)

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)

        x = mx.zeros((5, 5), mx.float16)
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)


if __name__ == "__main__":
    unittest.main()