Add Lion optimizer (#209)

* Add Lion optimizer
* Also update acknowledgements with past contributions
Justin Deschenaux 2023-12-20 22:54:58 +01:00 committed by GitHub
parent f40d17047d
commit 4912ff3ec2
3 changed files with 59 additions and 1 deletion

ACKNOWLEDGMENTS.md

@@ -8,6 +8,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
# Third-Party Software

docs/src/python/optimizers.rst

@@ -44,3 +44,4 @@ model's parameters and the **optimizer state**.
Adam
AdamW
Adamax
Lion
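
With this entry, Lion is documented alongside the other optimizers. A rough sketch of the learning rate and weight decay scaling recommended by the implementation below (the values are illustrative, not tuned):

import mlx.optimizers as optim

# AdamW baseline (illustrative values).
adamw = optim.AdamW(learning_rate=1e-4, weight_decay=0.01)

# Lion: lr 3-10x smaller and weight decay 3-10x larger than AdamW,
# keeping the product lr * wd roughly constant (here both scaled by 5x).
lion = optim.Lion(learning_rate=2e-5, weight_decay=0.05)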

python/mlx/optimizers.py

@@ -442,3 +442,59 @@ class Adamax(Adam):
        state["v"] = v
        return parameter - lr * m / (v + eps)


class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have a larger norm than those of other optimizers such as SGD and
    Adam. We recommend a learning rate that is 3-10x smaller than for
    AdamW and a weight decay 3-10x larger than for AdamW, to maintain
    the effective update strength (lr * wd). Our Lion implementation
    follows the original paper. In detail,

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)

    [1]: Chen, X., et al. Symbolic Discovery of Optimization Algorithms.
    arXiv preprint arXiv:2302.06675.

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        # Update direction: interpolate between momentum and gradient,
        # then take the sign; the momentum uses a separate coefficient.
        m = state.get("m", gradient)
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient

        # Decoupled weight decay, applied before the sign update.
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
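
For completeness, a minimal end-to-end sketch of using the new class (the model, data, and loss are hypothetical placeholders; the update/eval calls follow the existing Optimizer API in this file):

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import Lion

model = nn.Linear(10, 1)  # toy model for illustration

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = Lion(learning_rate=1e-4, weight_decay=0.1)

x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)  # runs apply_single on every parameter
mx.eval(model.parameters(), optimizer.state)  # force the lazy computation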