mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-24 22:36:39 +08:00
Added Adafactor (#415)
* Added adafactor
* Added Adafactor and ran pre-commit
* modified operations
* Added docstrings
* Switched two ops to fix a bug
* added underscore for internal functions and removed the plus sign in the last return statement
* Removed parameter rms from the optimizer state because it's not needed
* Added simple MNIST test for Adafactor and temporary training log
* remove test files
* nits in docs
* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 755dcf6137
commit 37fc9db82c
@@ -40,6 +40,7 @@ model's parameters and the **optimizer state**.
    SGD
    RMSprop
    Adagrad
+   Adafactor
    AdaDelta
    Adam
    AdamW
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import List
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 from mlx.utils import tree_map
@@ -76,7 +76,7 @@ class Optimizer:
 
 
 class SGD(Optimizer):
-    r"""Stochastic gradient descent optimizer.
+    r"""The stochastic gradient descent optimizer.
 
     Updates a parameter :math:`w` with a gradient :math:`g` as follows
 
@@ -141,7 +141,7 @@ class SGD(Optimizer):
 
 
 class RMSprop(Optimizer):
-    r"""Implementation of the RMSprop optimizer [1].
+    r"""The RMSprop optimizer [1].
 
     [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
 
@@ -190,7 +190,7 @@ class RMSprop(Optimizer):
 
 
 class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+    r"""The Adagrad optimizer [1].
 
     Our Adagrad implementation follows the original paper. In detail,
 
@@ -235,7 +235,7 @@ class Adagrad(Optimizer):
 
 
 class AdaDelta(Optimizer):
-    r"""Implementation of the AdaDelta optimizer with learning rate[1].
+    r"""The AdaDelta optimizer with a learning rate [1].
 
     Our AdaDelta implementation follows the original paper. In detail,
 
@@ -294,7 +294,7 @@ class AdaDelta(Optimizer):
 
 
 class Adam(Optimizer):
-    r"""Implementation of the Adam optimizer [1].
+    r"""The Adam optimizer [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
@@ -346,7 +346,7 @@ class Adam(Optimizer):
 
 
 class AdamW(Adam):
-    r"""Implementation of the AdamW optimizer [1].
+    r"""The AdamW optimizer [1].
 
     Following the above convention, in contrast with [1], we do not use bias
     correction in the first and second moments for AdamW. We update the weights
@@ -395,8 +395,7 @@ class AdamW(Adam):
 
 
 class Adamax(Adam):
-    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
-    on the infinity norm [1].
+    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
@@ -449,7 +448,7 @@ class Adamax(Adam):
 
 
 class Lion(Optimizer):
-    r"""Implementation of the Lion optimizer [1].
+    r"""The Lion optimizer [1].
 
     Since updates are computed through the sign operation, they tend to
     have larger norm than for other optimizers such as SGD and Adam.
@@ -502,3 +501,139 @@ class Lion(Optimizer):
         if weight_decay > 0:
             parameter = (1 - lr * weight_decay) * parameter
         return parameter - lr * mx.sign(c)
+
+
+class Adafactor(Optimizer):
+    r"""The Adafactor optimizer.
+
+    Our Adafactor implementation follows the original paper: `Adafactor:
+    Adaptive Learning Rates with Sublinear Memory Cost
+    <https://arxiv.org/abs/1804.04235>`_
+
+    Args:
+        learning_rate (float, optional): The learning rate. Default: ``None``.
+        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
+            is added to the square of the gradients to improve numerical
+            stability, and the second term :math:`\epsilon_2` is used for
+            parameter scaling if ``scale_parameter`` is set to ``True``.
+            Default: ``(1e-30, 1e-3)``.
+        clip_threshold (float, optional): Clips the unscaled update at
+            ``clip_threshold``. Default: ``1.0``.
+        decay_rate (float, optional): Coefficient for the running average
+            of the squared gradient. Default: ``-0.8``.
+        beta_1 (float, optional): If set to a value greater than zero, the
+            first moment will be used. Default: ``None``.
+        weight_decay (float, optional): The weight decay :math:`\lambda`.
+            Default: ``0.0``.
+        scale_parameter (bool, optional): If set to ``True`` the learning rate
+            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
+            Default: ``True``.
+        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
+            will be ignored and a relative step size will be computed instead.
+            Default: ``True``.
+        warmup_init (bool, optional): If set to ``True`` the relative step size
+            will be warmed up based on the current step. Default: ``False``.
+    """
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        eps: Tuple[float, float] = (1e-30, 1e-3),
+        clip_threshold: float = 1.0,
+        decay_rate: float = -0.8,
+        beta_1: Optional[float] = None,
+        weight_decay: float = 0.0,
+        scale_parameter: bool = True,
+        relative_step: bool = True,
+        warmup_init: bool = False,
+    ):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.eps = eps
+        self.clip_threshold = clip_threshold
+        self.decay_rate = decay_rate
+        self.beta_1 = beta_1
+        self.weight_decay = weight_decay
+        self.scale_parameter = scale_parameter
+        self.relative_step = relative_step
+        self.warmup_init = warmup_init
+
+    def _compute_rms(self, inputs):
+        # Root mean square, used for parameter scaling and update clipping.
+        return mx.sqrt(mx.mean(mx.square(inputs)))
+
+    def _compute_learning_rate(self, step, parameter_rms):
+        # Use the provided learning rate, or the relative step size
+        # min(min_step, 1 / sqrt(step)), optionally scaled by
+        # max(eps[1], RMS(parameter)).
+        relative_step_size = self.learning_rate
+        if self.relative_step:
+            min_step = 1e-6 * step if self.warmup_init else 1e-2
+            relative_step_size = min(min_step, 1 / math.sqrt(step))
+
+        parameter_scale = 1.0
+        if self.scale_parameter:
+            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
+        return parameter_scale * relative_step_size
+
+    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
+        # Reconstruct an approximate inverse square root of the second-moment
+        # matrix from the factored row and column statistics.
+        r_factor = mx.rsqrt(
+            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
+        )
+        c_factor = mx.rsqrt(exp_avg_sq_col)
+        return mx.matmul(
+            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
+        )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adafactor parameter and state update."""
+        gradient_shape = gradient.shape
+        factored = len(gradient_shape) >= 2
+
+        step = state.get("step", 0) + 1
+        state["step"] = step
+        use_first_moment = self.beta_1 is not None
+
+        parameter_rms = self._compute_rms(parameter)
+        learning_rate = self._compute_learning_rate(step, parameter_rms)
+        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        update = mx.square(gradient) + self.eps[0]
+
+        if factored:
+            # Matrices: keep only row and column running averages of the
+            # squared gradient (sublinear memory).
+            exp_avg_sq_row = state.get(
+                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
+            )
+            exp_avg_sq_col = state.get(
+                "exp_avg_sq_col",
+                mx.zeros(
+                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
+                ),
+            )
+            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
+                (1 - beta_2) * mx.mean(update, axis=-1)
+            )
+            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
+                (1 - beta_2) * mx.mean(update, axis=-2)
+            )
+            state["exp_avg_sq_row"] = exp_avg_sq_row
+            state["exp_avg_sq_col"] = exp_avg_sq_col
+            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
+            update = update * gradient
+        else:
+            # Vectors and scalars: keep the full second-moment estimate.
+            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
+            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
+            state["exp_avg_sq"] = exp_avg_sq
+            update = mx.rsqrt(exp_avg_sq) * gradient
+
+        # Clip the update so its RMS does not exceed clip_threshold, then
+        # scale it by the step size.
+        update = update / mx.maximum(
+            1.0, self._compute_rms(update) / self.clip_threshold
+        )
+        update = learning_rate * update
+
+        if use_first_moment:
+            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
+            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
+            state["exp_avg"] = exp_avg
+            update = exp_avg
+
+        if self.weight_decay != 0:
+            parameter += parameter * (-self.weight_decay * learning_rate)
+        return parameter - update
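For orientation, a minimal usage sketch of the new optimizer, driving it directly through ``apply_single`` the same way the unit test below does. It assumes ``mlx.optimizers`` is imported as ``opt``, as in the test; with the default ``relative_step=True`` no explicit learning rate is needed, and since the parameter is 2D the factored row/column second-moment path is used.

    import mlx.core as mx
    import mlx.optimizers as opt

    # A 2D parameter, so Adafactor keeps factored row and column statistics
    # (5 + 5 values) instead of a full 5 x 5 second-moment matrix.
    w = mx.zeros((5, 5))
    g = mx.ones_like(w)

    optimizer = opt.Adafactor()  # defaults: relative_step=True, scale_parameter=True
    for _ in range(3):
        # apply_single returns the updated parameter and records "step",
        # "exp_avg_sq_row" and "exp_avg_sq_col" in optimizer.state.
        w = optimizer.apply_single(g, w, optimizer.state)

    print(optimizer.state["step"])  # 3 after three updates
    print(w.shape, w.dtype)         # shape and dtype are preserved

In real training code the update would normally go through the ``Optimizer`` base class over the whole parameter tree rather than calling ``apply_single`` per array; the sketch above simply mirrors the test.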
@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
         self.assertTrue(all_equal)
 
+    def test_adafactor(self):
+        x = mx.zeros((5, 5))
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+
+        x = mx.zeros((5, 5), mx.float16)
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+        self.assertEqual(optimizer.state["step"], 2)
+
 
 if __name__ == "__main__":
     unittest.main()