# Copyright © 2023 Apple Inc.

import math
from typing import List

import mlx.core as mx
from mlx.utils import tree_map


class OptimizerState(dict):
    """The optimizer state implements a recursively defined
    :class:`collections.defaultdict`, namely a missing key in an optimizer
    state is an :class:`OptimizerState`.

    .. note::
       :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
       the key to the ``default`` value if the ``key`` was not present in the
       dictionary.
    """

    def __getitem__(self, key):
        if key not in self:
            self[key] = OptimizerState()
        return super().__getitem__(key)

    def get(self, key, default):
        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
        if key not in self:
            self[key] = default
        return super().__getitem__(key)


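# A minimal illustration of the OptimizerState semantics above (comment-only
# sketch, not executed): missing keys materialize as nested states, and
# ``get`` also writes the default back into the dictionary.
#
#     state = OptimizerState()
#     state["layer1"]                   # creates and returns an empty OptimizerState
#     state["layer1"].get("v", 0.0)     # stores 0.0 under "v" and returns it
#     assert state["layer1"]["v"] == 0.0

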
class Optimizer:
    """The base class for all optimizers. It allows us to implement an
    optimizer on a per-parameter basis and apply it to a parameter tree.

    Attributes:
        state (OptimizerState): It holds the optimizer's state dictionary.
    """

    def __init__(self):
        self.state = OptimizerState()

    def update(self, model: "mlx.nn.Module", gradients: dict):
        """Apply the gradients to the parameters of the model and update the
        model with the new parameters.

        Args:
            model (mlx.nn.Module): An mlx module to be updated.
            gradients (dict): A Python tree of gradients, most likely computed
                via :func:`mlx.nn.value_and_grad`.
        """
        model.update(self.apply_gradients(gradients, model))

    def apply_gradients(self, gradients: dict, model: dict):
        """Apply the gradients to the parameters and return the updated parameters.

        Can be used to update a model via
        ``model.update(opt.apply_gradients(grads, model))`` which is precisely
        how :meth:`Optimizer.update` is implemented.

        Args:
            gradients (dict): A Python tree of gradients.
            model (dict): A Python tree of parameters. It can be a superset of
                the gradients. In that case the returned Python tree will be of
                the same structure as the gradients.
        """
        return tree_map(self.apply_single, gradients, model, self.state)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """To be extended by the child classes to implement each optimizer's
        update."""
        raise NotImplementedError()


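# An illustrative training-step sketch using the Optimizer API above; ``model``,
# ``loss_fn`` and ``batch`` are hypothetical placeholders, and SGD is defined
# further down in this file.
#
#     loss_and_grad_fn = mlx.nn.value_and_grad(model, loss_fn)
#     optimizer = SGD(learning_rate=1e-2)
#     loss, grads = loss_and_grad_fn(model, batch)
#     optimizer.update(model, grads)
#     # equivalent: model.update(optimizer.apply_gradients(grads, model))
#     mx.eval(model.parameters(), optimizer.state)

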
class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a positive momentum and zero dampening."
            )
        super().__init__()

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state."""
        # Apply the L2 penalty before the momentum-free fast path so that
        # weight decay also takes effect when momentum is disabled.
        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        if self.momentum <= 0:
            return parameter - self.learning_rate * gradient

        v = state.get("v", mx.zeros_like(gradient))

        v = self.momentum * v
        if self.dampening > 0:
            v += (1 - self.dampening) * gradient
        else:
            v += gradient

        if self.nesterov:
            update = gradient + self.momentum * v
        else:
            update = v
        state["v"] = v
        return parameter - self.learning_rate * update


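# One hand-worked SGD-with-momentum step as a sanity check of the update above
# (illustrative numbers only): with learning_rate=0.1, momentum=0.9, no
# dampening or weight decay, and a constant gradient g = 1.0,
#
#     step 1: v = 0.9 * 0.0 + 1.0 = 1.0,  w <- w - 0.1 * 1.0
#     step 2: v = 0.9 * 1.0 + 1.0 = 1.9,  w <- w - 0.1 * 1.9
#
# so the effective step grows toward lr * g / (1 - momentum) = 1.0.

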
class RMSprop(Optimizer):
    r"""Implementation of the RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

    .. math::

        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        alpha (float, optional): The smoothing constant :math:`\alpha`.
            Default: ``0.99``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
            to improve numerical stability. Default: ``1e-8``
    """

    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
        super().__init__()

        self.learning_rate = learning_rate
        self.alpha = alpha
        self.eps = eps

        if self.alpha < 0.0:
            raise ValueError(
                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"RMSprop epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
        lr = self.learning_rate
        alpha = self.alpha
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        v = alpha * v + (1 - alpha) * mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


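# Illustrative instantiation (example values, not recommendations):
#
#     optimizer = RMSprop(learning_rate=1e-3, alpha=0.99, eps=1e-8)
#
# On the first step v = (1 - alpha) * g**2, so the update magnitude is roughly
# lr / sqrt(1 - alpha) = 10 * lr, independent of the gradient scale.

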
class Adagrad(Optimizer):
    r"""Implementation of the Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
    for online learning and stochastic optimization. JMLR 2011.

    .. math::

        v_{t+1} &= v_t + g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(self, learning_rate: float, eps: float = 1e-8):
        super().__init__()

        self.learning_rate = learning_rate
        self.eps = eps

        if self.eps < 0.0:
            raise ValueError(
                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adagrad parameter update and stores :math:`v` in the
        optimizer state."""
        lr = self.learning_rate
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        v = v + mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


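# Hand-worked behaviour of the accumulator above (illustrative only): with a
# constant gradient g, after t steps v = t * g**2, so the effective step decays
# like lr / sqrt(t):
#
#     optimizer = Adagrad(learning_rate=1e-2)
#     # step t: w <- w - 1e-2 * g / (sqrt(t) * |g| + eps)

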
class AdaDelta(Optimizer):
    r"""Implementation of the AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

    .. math::

        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
        w_{t+1} &= w_t - \lambda \Delta w_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
            to improve numerical stability. Default: ``1e-6``
    """

    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
        super().__init__()

        self.learning_rate = learning_rate
        self.rho = rho
        self.eps = eps
        if self.rho < 0.0:
            raise ValueError(
                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdaDelta parameter update and stores :math:`v` and
        :math:`u` in the optimizer state."""
        lr = self.learning_rate
        rho = self.rho
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        # u accumulates the squared updates; it must be read back under the
        # same key ("u") that it is stored under below.
        u = state.get("u", mx.zeros_like(gradient))

        v = rho * v + (1 - rho) * mx.square(gradient)
        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
        u = rho * u + (1 - rho) * mx.square(d)

        state["v"] = v
        state["u"] = u

        return parameter - lr * d


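# Illustrative instantiation (example values, not recommendations). Note that,
# unlike the original paper, this variant additionally scales the computed
# update d by an explicit learning rate:
#
#     optimizer = AdaDelta(learning_rate=1.0, rho=0.9, eps=1e-6)

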
class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", gradient)
        v = state.get("v", mx.square(gradient))
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)


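# Hand-worked first Adam step as a sanity check of the update above
# (illustrative only): m and v are seeded with g and g**2, so
#
#     m_1 = b1 * g + (1 - b1) * g = g
#     v_1 = b2 * g**2 + (1 - b2) * g**2 = g**2
#     w <- w - lr * g / (|g| + eps)      # roughly a step of size lr per coordinate
#
# which is why omitting the bias correction does not shrink the first updates.

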
class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
    with a weight_decay (:math:`\lambda`) value:

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\alpha`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.01``.
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdamW parameter update by modifying the parameters
        passed into Adam.
        """

        return super().apply_single(
            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
        )


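# The decoupled decay above is applied by shrinking the parameter before the
# Adam step, which is algebraically the same as the docstring formula:
#
#     (1 - lr * wd) * w - lr * m / (sqrt(v) + eps)
#         == w - lr * (m / (sqrt(v) + eps) + wd * w)
#
# Illustrative instantiation (example values, not recommendations):
#
#     optimizer = AdamW(learning_rate=1e-3, weight_decay=0.01)

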
class Adamax(Adam):
    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
    on the infinity norm [1].

    Our Adamax implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
    ):
        super().__init__(learning_rate, betas, eps)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adamax parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", mx.zeros_like(gradient))
        v = state.get("v", mx.zeros_like(gradient))

        m = b1 * m + (1 - b1) * gradient
        v = mx.maximum(b2 * v, mx.abs(gradient))
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (v + eps)


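# Hand-worked first Adamax step (illustrative only): starting from m = v = 0
# with gradient g and the default betas,
#
#     m_1 = (1 - 0.9) * g = 0.1 * g
#     v_1 = max(0.999 * 0, |g|) = |g|
#     w <- w - lr * 0.1 * g / (|g| + eps)
#
# so, unlike Adam above, the very first update is damped by (1 - beta_1).

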
class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing the gradient
            momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state.get("m", gradient)
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


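# Illustrative instantiation following the recommendation above (example
# values, not tuned settings): a roughly 10x smaller learning rate and 10x
# larger weight decay than a comparable AdamW configuration.
#
#     optimizer = Lion(learning_rate=1e-4, betas=[0.9, 0.99], weight_decay=0.1)
#
# Because the update is sign(c), every coordinate with a nonzero c moves by
# exactly lr (plus the decoupled decay term), regardless of gradient magnitude.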