diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index 7cc6ef906..fe8632a7e 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -40,6 +40,7 @@ model's parameters and the **optimizer state**. SGD RMSprop Adagrad + Adafactor AdaDelta Adam AdamW diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py index de77d9bea..8b9965e78 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import List +from typing import List, Optional, Tuple import mlx.core as mx from mlx.utils import tree_map @@ -76,7 +76,7 @@ class Optimizer: class SGD(Optimizer): - r"""Stochastic gradient descent optimizer. + r"""The stochastic gradient descent optimizer. Updates a parameter :math:`w` with a gradient :math:`g` as follows @@ -141,7 +141,7 @@ class SGD(Optimizer): class RMSprop(Optimizer): - r"""Implementation of the RMSprop optimizer [1]. + r"""The RMSprop optimizer [1]. [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning @@ -190,7 +190,7 @@ class RMSprop(Optimizer): class Adagrad(Optimizer): - r"""Implementation of the Adagrad optimizer [1]. + r"""The Adagrad optimizer [1]. Our Adagrad implementation follows the original paper. In detail, @@ -235,7 +235,7 @@ class Adagrad(Optimizer): class AdaDelta(Optimizer): - r"""Implementation of the AdaDelta optimizer with learning rate[1]. + r"""The AdaDelta optimizer with a learning rate [1]. Our AdaDelta implementation follows the original paper. In detail, @@ -294,7 +294,7 @@ class AdaDelta(Optimizer): class Adam(Optimizer): - r"""Implementation of the Adam optimizer [1]. + r"""The Adam optimizer [1]. Our Adam implementation follows the original paper and omits the bias correction in the first and second moment estimates. In detail, @@ -346,7 +346,7 @@ class Adam(Optimizer): class AdamW(Adam): - r"""Implementation of the AdamW optimizer [1]. + r"""The AdamW optimizer [1]. Following the above convention, in contrast with [1], we do not use bias correction in the first and second moments for AdamW. We update the weights @@ -395,8 +395,7 @@ class AdamW(Adam): class Adamax(Adam): - r"""Implementation of the Adamax optimizer. It is a variant of Adam based - on the infinity norm [1]. + r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1]. Our Adam implementation follows the original paper and omits the bias correction in the first and second moment estimates. In detail, @@ -449,7 +448,7 @@ class Adamax(Adam): class Lion(Optimizer): - r"""Implementation of the Lion optimizer [1]. + r"""The Lion optimizer [1]. Since updates are computed through the sign operation, they tend to have larger norm than for other optimizers such as SGD and Adam. @@ -502,3 +501,139 @@ class Lion(Optimizer): if weight_decay > 0: parameter = (1 - lr * weight_decay) * parameter return parameter - lr * mx.sign(c) + + +class Adafactor(Optimizer): + r"""The Adafactor optimizer. + + Our Adafactor implementation follows the original paper: `Adafactor: + Adaptive Learning Rates with Sublinear Memory Cost + `_ + + Args: + learning_rate (float, optional): The learning rate. Default: ``None``. + eps (tuple(float, float), optional): The first term :math:`\epsilon_1` + added to the square of the gradients to improve numerical + stability and the second term :math:`\epsilon_2` is used for + parameter scaling if ``parameter_scale`` is set to ``True``. + Default: ``(1e-30, 1e-3)``. + clip_threshold (float, optional): Clips the unscaled update at + ``clip_threshold``. Default: ``1.0``. + decay_rate (float, optional): Coefficient for the running average + of the squared gradient. Default: ``-0.8``. + beta_1 (float, optional): If set to a value bigger than zero + then first moment will be used. Default: ``None``. + weight_decay (float, optional): The weight decay :math:`\lambda`. + Default: ``0.0``. + scale_parameter (bool, optional): If set to ``True`` the learning rate + will be scaled by :math:`\max(\epsilon_1, \text{RMS}(w_{t-1}))`. + Default: ``True``. + relative_step (bool, optional): If set to ``True`` the ``learning_rate`` + will be ignored and relative step size will be computed. + Default: ``True``. + warmup_init (bool, optional): If set to ``True`` then the relative + step size will be calculated by the current step. Default: + ``False``. + """ + + def __init__( + self, + learning_rate: Optional[float] = None, + eps: Tuple[float, float] = (1e-30, 1e-3), + clip_threshold: float = 1.0, + decay_rate: float = -0.8, + beta_1: Optional[float] = None, + weight_decay: float = 0.0, + scale_parameter: bool = True, + relative_step: bool = True, + warmup_init: bool = False, + ): + super().__init__() + self.learning_rate = learning_rate + self.eps = eps + self.clip_threshold = clip_threshold + self.decay_rate = decay_rate + self.beta_1 = beta_1 + self.weight_decay = weight_decay + self.scale_parameter = scale_parameter + self.relative_step = relative_step + self.warmup_init = warmup_init + + def _compute_rms(self, inputs): + return mx.sqrt(mx.mean(mx.square(inputs))) + + def _compute_learning_rate(self, step, parameter_rms): + relative_step_size = self.learning_rate + if self.relative_step: + min_step = 1e-6 * step if self.warmup_init else 1e-2 + relative_step_size = min(min_step, 1 / math.sqrt(step)) + + parameter_scale = 1.0 + if self.scale_parameter: + parameter_scale = mx.maximum(self.eps[1], parameter_rms) + return parameter_scale * relative_step_size + + def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = mx.rsqrt( + exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True) + ) + c_factor = mx.rsqrt(exp_avg_sq_col) + return mx.matmul( + mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) + ) + + def apply_single( + self, gradient: mx.array, parameter: mx.array, state: OptimizerState + ): + """Performs the Adafactor parameter and state update.""" + gradient_shape = gradient.shape + factored = len(gradient_shape) >= 2 + step = state.get("step", 0) + 1 + state["step"] = step + use_first_moment = self.beta_1 is not None + + parameter_rms = self._compute_rms(parameter) + learning_rate = self._compute_learning_rate(step, parameter_rms) + beta_2 = 1.0 - math.pow(step, self.decay_rate) + update = mx.square(gradient) + self.eps[0] + + if factored: + exp_avg_sq_row = state.get( + "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) + ) + exp_avg_sq_col = state.get( + "exp_avg_sq_col", + mx.zeros( + gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype + ), + ) + exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( + (1 - beta_2) * mx.mean(update, axis=-1) + ) + exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + ( + (1 - beta_2) * mx.mean(update, axis=-2) + ) + state["exp_avg_sq_row"] = exp_avg_sq_row + state["exp_avg_sq_col"] = exp_avg_sq_col + update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) + update = update * gradient + else: + exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) + exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) + state["exp_avg_sq"] = exp_avg_sq + update = mx.rsqrt(exp_avg_sq) * gradient + + update = update / mx.maximum( + 1.0, self._compute_rms(update) / self.clip_threshold + ) + update = learning_rate * update + + if use_first_moment: + exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) + exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) + state["exp_avg"] = exp_avg + update = exp_avg + + if self.weight_decay != 0: + parameter += parameter * (-self.weight_decay * learning_rate) + return parameter - update diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index b0a3165bb..59046184f 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase): all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) self.assertTrue(all_equal) + def test_adafactor(self): + x = mx.zeros((5, 5)) + grad = mx.ones_like(x) + optimizer = opt.Adafactor() + for _ in range(2): + xp = optimizer.apply_single(grad, x, optimizer.state) + self.assertEqual(xp.dtype, x.dtype) + self.assertEqual(xp.shape, x.shape) + + x = mx.zeros((5, 5), mx.float16) + grad = mx.ones_like(x) + optimizer = opt.Adafactor() + for _ in range(2): + xp = optimizer.apply_single(grad, x, optimizer.state) + self.assertEqual(xp.dtype, x.dtype) + self.assertEqual(xp.shape, x.shape) + self.assertEqual(optimizer.state["step"], 2) + if __name__ == "__main__": unittest.main()