mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-24 22:36:39 +08:00
Added Adafactor (#415)
* Added adafactor
* Added Adafactor and ran pre-commit
* modified operations
* Added docstrings
* Switched two ops to fix a bug
* added underscore for internal functions and removed the plus sign in the last return statement
* Removed parameter rms from the optimizer state because it's not needed
* Added simple MNIST test for Adafactor and temporary training log
* remove test files
* nits in docs
* comment nit

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 755dcf6137
commit 37fc9db82c
@@ -40,6 +40,7 @@ model's parameters and the **optimizer state**.
SGD
RMSprop
Adagrad
+Adafactor
AdaDelta
Adam
AdamW
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
@@ -76,7 +76,7 @@ class Optimizer:


class SGD(Optimizer):
-    r"""Stochastic gradient descent optimizer.
+    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

@@ -141,7 +141,7 @@ class SGD(Optimizer):


class RMSprop(Optimizer):
-    r"""Implementation of the RMSprop optimizer [1].
+    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

@@ -190,7 +190,7 @@ class RMSprop(Optimizer):


class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

@@ -235,7 +235,7 @@ class Adagrad(Optimizer):


class AdaDelta(Optimizer):
-    r"""Implementation of the AdaDelta optimizer with learning rate[1].
+    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

@@ -294,7 +294,7 @@ class AdaDelta(Optimizer):


class Adam(Optimizer):
-    r"""Implementation of the Adam optimizer [1].
+    r"""The Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

@@ -346,7 +346,7 @@ class Adam(Optimizer):


class AdamW(Adam):
-    r"""Implementation of the AdamW optimizer [1].
+    r"""The AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights

@@ -395,8 +395,7 @@ class AdamW(Adam):


class Adamax(Adam):
-    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
-    on the infinity norm [1].
+    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

@@ -449,7 +448,7 @@ class Adamax(Adam):


class Lion(Optimizer):
-    r"""Implementation of the Lion optimizer [1].
+    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.

@@ -502,3 +501,139 @@ class Lion(Optimizer):
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


class Adafactor(Optimizer):
    r"""The Adafactor optimizer.

    Our Adafactor implementation follows the original paper: `Adafactor:
    Adaptive Learning Rates with Sublinear Memory Cost
    <https://arxiv.org/abs/1804.04235>`_

    Args:
        learning_rate (float, optional): The learning rate. Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
            parameter scaling if ``parameter_scale`` is set to ``True``.
            Default: ``(1e-30, 1e-3)``.
        clip_threshold (float, optional): Clips the unscaled update at
            ``clip_threshold``. Default: ``1.0``.
        decay_rate (float, optional): Coefficient for the running average
            of the squared gradient. Default: ``-0.8``.
        beta_1 (float, optional): If set to a value bigger than zero
            then first moment will be used. Default: ``None``.
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.0``.
        scale_parameter (bool, optional): If set to ``True`` the learning rate
            will be scaled by :math:`\max(\epsilon_1, \text{RMS}(w_{t-1}))`.
            Default: ``True``.
        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
            will be ignored and relative step size will be computed.
            Default: ``True``.
        warmup_init (bool, optional): If set to ``True`` then the relative
            step size will be calculated by the current step. Default:
            ``False``.
    """

    def __init__(
        self,
        learning_rate: Optional[float] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta_1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta_1 = beta_1
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init

    def _compute_rms(self, inputs):
        return mx.sqrt(mx.mean(mx.square(inputs)))

    def _compute_learning_rate(self, step, parameter_rms):
        relative_step_size = self.learning_rate
        if self.relative_step:
            min_step = 1e-6 * step if self.warmup_init else 1e-2
            relative_step_size = min(min_step, 1 / math.sqrt(step))

        parameter_scale = 1.0
        if self.scale_parameter:
            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
        return parameter_scale * relative_step_size

    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = mx.rsqrt(
            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
        )
        c_factor = mx.rsqrt(exp_avg_sq_col)
        return mx.matmul(
            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
        )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adafactor parameter and state update."""
        gradient_shape = gradient.shape
        factored = len(gradient_shape) >= 2
        step = state.get("step", 0) + 1
        state["step"] = step
        use_first_moment = self.beta_1 is not None

        parameter_rms = self._compute_rms(parameter)
        learning_rate = self._compute_learning_rate(step, parameter_rms)
        beta_2 = 1.0 - math.pow(step, self.decay_rate)
        update = mx.square(gradient) + self.eps[0]

        if factored:
            exp_avg_sq_row = state.get(
                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
            )
            exp_avg_sq_col = state.get(
                "exp_avg_sq_col",
                mx.zeros(
                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
                ),
            )
            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                (1 - beta_2) * mx.mean(update, axis=-1)
            )
            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
                (1 - beta_2) * mx.mean(update, axis=-2)
            )
            state["exp_avg_sq_row"] = exp_avg_sq_row
            state["exp_avg_sq_col"] = exp_avg_sq_col
            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
            update = update * gradient
        else:
            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
            state["exp_avg_sq"] = exp_avg_sq
            update = mx.rsqrt(exp_avg_sq) * gradient

        update = update / mx.maximum(
            1.0, self._compute_rms(update) / self.clip_threshold
        )
        update = learning_rate * update

        if use_first_moment:
            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
            state["exp_avg"] = exp_avg
            update = exp_avg

        if self.weight_decay != 0:
            parameter += parameter * (-self.weight_decay * learning_rate)
        return parameter - update
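The factored branch above is what gives Adafactor its sublinear memory cost: for a matrix-shaped parameter the state holds one row statistic (`exp_avg_sq_row`) and one column statistic (`exp_avg_sq_col`) instead of a full matrix of squared gradients. A minimal sketch of exercising the new class on a bare 2-D array, mirroring how the test below calls it and assuming the class is exposed as `mlx.optimizers.Adafactor`:

import mlx.core as mx
import mlx.optimizers as opt

# Sketch only: assumes Adafactor is importable from mlx.optimizers.
optimizer = opt.Adafactor()   # relative_step=True by default, so no learning_rate is needed

w = mx.zeros((64, 32))        # a 2-D parameter takes the factored branch
g = mx.ones_like(w)           # stand-in gradient

w = optimizer.apply_single(g, w, optimizer.state)

# The state keeps a length-64 row statistic and a length-32 column statistic
# rather than a full 64x32 matrix of second moments.
print(optimizer.state["exp_avg_sq_row"].shape)
print(optimizer.state["exp_avg_sq_col"].shape)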
@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
        self.assertTrue(all_equal)

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)

        x = mx.zeros((5, 5), mx.float16)
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)


if __name__ == "__main__":
    unittest.main()
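The test drives `apply_single` directly on a single array. Inside a model, the optimizer would normally be driven through the base `Optimizer` interface instead. A rough sketch of one training step, assuming the usual `mlx.nn` / `mlx.optimizers` APIs of this period (`nn.value_and_grad`, `Optimizer.update`) rather than anything shown in this diff:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

# Hypothetical usage sketch; the model, batch, and API calls below are
# assumptions, not part of this commit.
model = nn.Linear(784, 10)
optimizer = opt.Adafactor()   # relative step size, no learning_rate given


def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))


loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

X = mx.random.normal((32, 784))       # stand-in batch of inputs
y = mx.random.randint(0, 10, (32,))   # stand-in integer labels

loss, grads = loss_and_grad_fn(model, X, y)
optimizer.update(model, grads)        # applies apply_single to every parameter
mx.eval(model.parameters(), optimizer.state)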