Added Adafactor (#415)

* Added Adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* Added underscore prefix for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because it's not needed

* Added simple MNIST test for Adafactor and temporary training log

* Removed test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Hazem Essam 2024-01-24 01:11:27 +02:00 committed by GitHub
parent 755dcf6137
commit 37fc9db82c
3 changed files with 164 additions and 10 deletions

View File

@@ -40,6 +40,7 @@ model's parameters and the **optimizer state**.
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.

import math
from typing import List
from typing import List, Optional, Tuple

import mlx.core as mx
from mlx.utils import tree_map
@@ -76,7 +76,7 @@ class Optimizer:
class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.
    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows
@@ -141,7 +141,7 @@ class SGD(Optimizer):
class RMSprop(Optimizer):
    r"""Implementation of the RMSprop optimizer [1].
    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
@@ -190,7 +190,7 @@ class RMSprop(Optimizer):
class Adagrad(Optimizer):
    r"""Implementation of the Adagrad optimizer [1].
    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,
@@ -235,7 +235,7 @@ class Adagrad(Optimizer):
class AdaDelta(Optimizer):
    r"""Implementation of the AdaDelta optimizer with learning rate[1].
    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,
@@ -294,7 +294,7 @@ class AdaDelta(Optimizer):
class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].
    r"""The Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,
@@ -346,7 +346,7 @@ class Adam(Optimizer):
class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].
    r"""The AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
@@ -395,8 +395,7 @@ class AdamW(Adam):
class Adamax(Adam):
    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
    on the infinity norm [1].
    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,
@@ -449,7 +448,7 @@ class Adamax(Adam):
class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].
    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
@@ -502,3 +501,139 @@ class Lion(Optimizer):
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


class Adafactor(Optimizer):
    r"""The Adafactor optimizer.

    Our Adafactor implementation follows the original paper: `Adafactor:
    Adaptive Learning Rates with Sublinear Memory Cost
    <https://arxiv.org/abs/1804.04235>`_

    Args:
        learning_rate (float, optional): The learning rate. Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            is added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
            parameter scaling if ``scale_parameter`` is set to ``True``.
            Default: ``(1e-30, 1e-3)``.
        clip_threshold (float, optional): Clips the unscaled update at
            ``clip_threshold``. Default: ``1.0``.
        decay_rate (float, optional): Coefficient for the running average
            of the squared gradient. Default: ``-0.8``.
        beta_1 (float, optional): If set, the first moment will be used.
            Default: ``None``.
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.0``.
        scale_parameter (bool, optional): If set to ``True`` the learning rate
            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
            Default: ``True``.
        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
            will be ignored and a relative step size will be computed instead.
            Default: ``True``.
        warmup_init (bool, optional): If set to ``True`` then the relative
            step size will be calculated by the current step, i.e.
            ``1e-6 * step``. Default: ``False``.
    """

    def __init__(
        self,
        learning_rate: Optional[float] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta_1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta_1 = beta_1
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init

    def _compute_rms(self, inputs):
        return mx.sqrt(mx.mean(mx.square(inputs)))

    def _compute_learning_rate(self, step, parameter_rms):
        relative_step_size = self.learning_rate
        if self.relative_step:
            # Relative step size from the paper: min(1e-2, 1 / sqrt(t)), or
            # min(1e-6 * t, 1 / sqrt(t)) when warmup_init is enabled.
            min_step = 1e-6 * step if self.warmup_init else 1e-2
            relative_step_size = min(min_step, 1 / math.sqrt(step))

        parameter_scale = 1.0
        if self.scale_parameter:
            # Scale the step by max(eps_2, RMS(parameter)).
            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
        return parameter_scale * relative_step_size

    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
        # Rank-1 approximation of the second moment: the outer product of the
        # (normalized) row and column statistics, returned as its elementwise
        # reciprocal square root.
        r_factor = mx.rsqrt(
            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
        )
        c_factor = mx.rsqrt(exp_avg_sq_col)
        return mx.matmul(
            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
        )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adafactor parameter and state update."""
        gradient_shape = gradient.shape
        # Only factor the second-moment statistics for parameters with at
        # least two dimensions.
        factored = len(gradient_shape) >= 2

        step = state.get("step", 0) + 1
        state["step"] = step
        use_first_moment = self.beta_1 is not None

        parameter_rms = self._compute_rms(parameter)
        learning_rate = self._compute_learning_rate(step, parameter_rms)
        beta_2 = 1.0 - math.pow(step, self.decay_rate)
        update = mx.square(gradient) + self.eps[0]

        if factored:
            # Factored case: keep running averages of the row and column
            # means of the squared gradient instead of the full matrix.
            exp_avg_sq_row = state.get(
                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
            )
            exp_avg_sq_col = state.get(
                "exp_avg_sq_col",
                mx.zeros(
                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
                ),
            )
            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                (1 - beta_2) * mx.mean(update, axis=-1)
            )
            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
                (1 - beta_2) * mx.mean(update, axis=-2)
            )
            state["exp_avg_sq_row"] = exp_avg_sq_row
            state["exp_avg_sq_col"] = exp_avg_sq_col
            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
            update = update * gradient
        else:
            # Unfactored case: keep the full running average of the squared
            # gradient.
            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
            state["exp_avg_sq"] = exp_avg_sq
            update = mx.rsqrt(exp_avg_sq) * gradient

        # Clip the unscaled update by its RMS and apply the learning rate.
        update = update / mx.maximum(
            1.0, self._compute_rms(update) / self.clip_threshold
        )
        update = learning_rate * update

        if use_first_moment:
            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
            state["exp_avg"] = exp_avg
            update = exp_avg

        if self.weight_decay != 0:
            parameter += parameter * (-self.weight_decay * learning_rate)
        return parameter - update
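
For reference, a minimal usage sketch of the new optimizer (not part of this diff), assuming the usual mlx.nn training-loop API (``nn.value_and_grad``, ``optimizer.update``); the model, data, and loss below are placeholders:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 1)

def loss_fn(model, X, y):
    return mx.mean(mx.square(model(X) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# With relative_step=True (the default) no learning rate is needed; the step
# size is derived from the iteration count. Pass learning_rate and
# relative_step=False to use a fixed step instead.
optimizer = optim.Adafactor()

X, y = mx.random.normal((32, 10)), mx.random.normal((32, 1))
for _ in range(10):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)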

View File

@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
        self.assertTrue(all_equal)

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)

        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)

        x = mx.zeros((5, 5), mx.float16)
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)

        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)


if __name__ == "__main__":
    unittest.main()
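
A quick sketch (not part of the commit) of where the sublinear memory cost comes from: for a factored (2-D) parameter the optimizer state holds one row vector and one column vector of second-moment statistics rather than a full matrix. It reuses the ``apply_single`` / ``optimizer.state`` pattern from the test above and assumes the same ``mlx.optimizers`` import:

import mlx.core as mx
import mlx.optimizers as opt

x = mx.zeros((5, 5))
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
optimizer.apply_single(grad, x, optimizer.state)

# One row statistic and one column statistic, each with 5 entries,
# instead of a full 5 x 5 second-moment matrix.
print(optimizer.state["exp_avg_sq_row"].shape)
print(optimizer.state["exp_avg_sq_col"].shape)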