Add Lion optimizer (#209)

* Add Lion optimizer
* Update acknowledgements also with past contributions
Justin Deschenaux 2023-12-20 22:54:58 +01:00 committed by GitHub
parent f40d17047d
commit 4912ff3ec2
3 changed files with 59 additions and 1 deletion


@@ -8,6 +8,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
# Third-Party Software
@@ -244,4 +245,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -44,3 +44,4 @@ model's parameters and the **optimizer state**.
Adam
AdamW
Adamax
Lion


@@ -442,3 +442,59 @@ class Adamax(Adam):
state["v"] = v state["v"] = v
return parameter - lr * m / (v + eps) return parameter - lr * m / (v + eps)
class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t
        w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state.get("m", gradient)
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
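
For context, a minimal sketch of how the new optimizer might be exercised is shown below. It assumes the usual mlx.nn training pattern (nn.value_and_grad plus optimizer.update); the toy model, data, and hyperparameters are purely illustrative and are not part of this commit.

# Illustrative usage sketch only; toy model and random data, not part of the diff.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 1)

def loss_fn(model, x, y):
    # Simple mean-squared-error loss.
    return mx.mean((model(x) - y) ** 2)

# Following the docstring's advice: learning rate ~3-10x smaller and weight
# decay ~3-10x larger than what you would use with AdamW.
optimizer = optim.Lion(learning_rate=1e-4, weight_decay=1e-1)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))

for _ in range(100):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # applies the Lion update to each parameter
    mx.eval(model.parameters(), optimizer.state)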