Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 01:41:17 +08:00
Add Lion optimizer (#209)
* Add Lion optimizer
* Update acknowledgements also with past contributions
This commit is contained in:
parent f40d17047d
commit 4912ff3ec2
ACKNOWLEDGMENTS.md
@@ -8,6 +8,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
 - Juarez Bochi: Fixed bug in cross attention.
+- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
 
 # Third-Party Software
docs/src/python/optimizers.rst
@@ -44,3 +44,4 @@ model's parameters and the **optimizer state**.
    Adam
    AdamW
    Adamax
+   Lion
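The optimizers listed above live in the mlx.optimizers module and share the same construction and update pattern. The sketch below shows the newly listed Lion next to the usage one would have with AdamW, using the 3-10x smaller learning rate recommended in its docstring; the mlx.nn helpers, nn.value_and_grad, and optimizer.update calls are assumptions about the surrounding MLX API rather than part of this diff.

# Hedged usage sketch (not part of this diff): module paths and helper calls
# below are assumptions about the MLX API around this commit.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

# The Lion docstring below recommends a 3-10x smaller learning rate (and a
# correspondingly larger weight decay) than one would use with AdamW.
optimizer = optim.Lion(learning_rate=1e-4, betas=[0.9, 0.99], weight_decay=0.1)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)                    # sign-based Lion step
mx.eval(model.parameters(), optimizer.state)      # materialize lazy updates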
python/mlx/optimizers.py
@@ -442,3 +442,59 @@ class Adamax(Adam):
         state["v"] = v
 
         return parameter - lr * m / (v + eps)
+
+
+class Lion(Optimizer):
+    r"""Implementation of the Lion optimizer [1].
+
+    Since updates are computed through the sign operation, they tend to
+    have larger norm than for other optimizers such as SGD and Adam.
+    We recommend a learning rate that is 3-10x smaller than AdamW and a
+    weight decay 3-10x larger than AdamW to maintain the strength
+    (lr * wd). Our Lion implementation follows the original paper. In
+    detail,
+
+    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
+    preprint arXiv:2302.06675.
+
+    .. math::
+
+        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
+        w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)
+
+    Args:
+        learning_rate (float): The learning rate :math:`\eta`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing the gradient
+          momentum and update direction. Default: ``(0.9, 0.99)``
+        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
+    """
+
+    def __init__(
+        self,
+        learning_rate: float,
+        betas: List[float] = [0.9, 0.99],
+        weight_decay: float = 0.0,
+    ):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.betas = betas
+        self.weight_decay = weight_decay
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Lion parameter update and stores :math:`m`
+        in the optimizer state."""
+        lr = self.learning_rate
+        b1, b2 = self.betas
+        weight_decay = self.weight_decay
+
+        m = state.get("m", gradient)
+        c = b1 * m + (1 - b1) * gradient
+        state["m"] = b2 * m + (1 - b2) * gradient
+        if weight_decay > 0:
+            parameter = (1 - lr * weight_decay) * parameter
+        return parameter - lr * mx.sign(c)
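For readers checking the docstring math against the code, here is a minimal single-step sketch that mirrors apply_single above with concrete numbers; the values and the mlx.core import are illustrative assumptions, not part of the commit.

# One Lion step on a toy parameter, mirroring apply_single line by line.
import mlx.core as mx

lr, b1, b2, weight_decay = 1e-4, 0.9, 0.99, 0.1

parameter = mx.array([1.0, -2.0, 3.0])
gradient = mx.array([0.5, -0.5, 0.25])
m = gradient                                      # first step: state.get("m", gradient)

c = b1 * m + (1 - b1) * gradient                  # update direction uses beta_1
m = b2 * m + (1 - b2) * gradient                  # stored momentum uses beta_2

parameter = (1 - lr * weight_decay) * parameter   # decoupled weight decay
parameter = parameter - lr * mx.sign(c)           # sign-based step
print(parameter)

Because the step is lr * sign(c), every coordinate moves by exactly lr (before weight decay), which is why the docstring recommends a smaller learning rate than AdamW.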