Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-06-24 17:31:16 +08:00
Add Lion optimizer (#209)
* Add Lion optimizer
* Update acknowledgements also with past contributions
This commit is contained in:
parent f40d17047d
commit 4912ff3ec2
@@ -8,6 +8,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
# Third-Party Software
@@ -44,3 +44,4 @@ model's parameters and the **optimizer state**.
Adam
AdamW
Adamax
Lion
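
The list above is the docs entry for the supported optimizers. As a rough illustration of how any of them interacts with the model's parameters and the optimizer state, here is a minimal training-step sketch. It is not part of this commit, and it assumes the usual MLX Python APIs (mlx.nn.value_and_grad, Optimizer.update, Optimizer.state) as they exist around this point in the repository's history.

# Usage sketch (illustration only, not part of this diff). The optimizer
# rewrites the model's parameters and keeps per-parameter state between
# steps (for Lion, the momentum "m").
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.Lion(learning_rate=1e-4)

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))

loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)                # applies apply_single per parameter
mx.eval(model.parameters(), optimizer.state)  # materialize parameters and state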
@@ -442,3 +442,59 @@ class Adamax(Adam):
        state["v"] = v

        return parameter - lr * m / (v + eps)


class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have a larger norm than other optimizers such as SGD and Adam. We
    recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state.get("m", gradient)
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
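
The docstring above recommends a learning rate 3-10x smaller and a weight decay 3-10x larger than what one would use with AdamW, so that the product lr * wd stays roughly constant. A hedged illustration of that scaling, with made-up numbers rather than values taken from the commit:

# If AdamW were run with lr=1e-3 and weight_decay=0.01 (lr * wd = 1e-5),
# Lion gets a ~5x smaller learning rate and a ~5x larger weight decay so
# that lr * wd is unchanged. The concrete values here are assumptions.
import mlx.optimizers as optim

adamw = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)
lion = optim.Lion(learning_rate=2e-4, betas=[0.9, 0.99], weight_decay=0.05)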