diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 35132d514..c9969f8d6 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -8,6 +8,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
 - Juarez Bochi: Fixed bug in cross attention.
+- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, Lion optimizer, linear and logistic regression Python example.
 
 # Third-Party Software
 
@@ -244,4 +245,4 @@ Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
-limitations under the License.
+limitations under the License.
\ No newline at end of file
diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst
index b8e5cfea7..7cc6ef906 100644
--- a/docs/src/python/optimizers.rst
+++ b/docs/src/python/optimizers.rst
@@ -44,3 +44,4 @@ model's parameters and the **optimizer state**.
    Adam
    AdamW
    Adamax
+   Lion
diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index 161d923f6..4fc2f6eed 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -442,3 +442,59 @@ class Adamax(Adam):
         state["v"] = v
 
         return parameter - lr * m / (v + eps)
+
+
+class Lion(Optimizer):
+    r"""Implementation of the Lion optimizer [1].
+
+    Since updates are computed through the sign operation, they tend to
+    have a larger norm than updates from other optimizers such as SGD
+    and Adam. We recommend a learning rate that is 3-10x smaller than
+    AdamW and a weight decay 3-10x larger than AdamW so that the
+    effective strength (lr * wd) stays comparable. Our Lion
+    implementation follows the original paper. In detail,
+
+    .. math::
+
+        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
+        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)
+
+    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
+    preprint arXiv:2302.06675.
+
+    Args:
+        learning_rate (float): The learning rate :math:`\eta`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing the gradient
+          momentum and update direction. Default: ``(0.9, 0.99)``
+        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
+    """
+
+    def __init__(
+        self,
+        learning_rate: float,
+        betas: List[float] = [0.9, 0.99],
+        weight_decay: float = 0.0,
+    ):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.betas = betas
+        self.weight_decay = weight_decay
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Lion parameter update and stores :math:`m`
+        in the optimizer state."""
+        lr = self.learning_rate
+        b1, b2 = self.betas
+        weight_decay = self.weight_decay
+
+        m = state.get("m", gradient)
+        c = b1 * m + (1 - b1) * gradient
+        state["m"] = b2 * m + (1 - b2) * gradient
+        if weight_decay > 0:
+            parameter = (1 - lr * weight_decay) * parameter
+        return parameter - lr * mx.sign(c)
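
For reviewers, a minimal usage sketch of the new `Lion` optimizer (not part of the patch). The toy linear model, random data, and MSE loss below are illustrative assumptions; the learning rate and weight decay follow the 3-10x-vs-AdamW guidance in the docstring.

```python
# Sketch only: train a tiny hypothetical model with the Lion optimizer.
import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import Lion

model = nn.Linear(10, 1)  # assumed toy model

# Per the docstring: lr ~3-10x smaller and weight decay ~3-10x larger than a
# typical AdamW setting, keeping the product lr * wd roughly comparable.
optimizer = Lion(learning_rate=1e-4, weight_decay=1e-1)

def loss_fn(model, x, y):
    # Plain mean-squared error built from mx ops.
    return mx.mean(mx.square(model(x) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))

for step in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    # update() applies Lion.apply_single to each parameter/gradient pair and
    # keeps the momentum m in optimizer.state.
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
```

The weight decay is applied in decoupled (AdamW-style) form, i.e. the parameter is scaled by (1 - lr * wd) before the sign update is subtracted, which matches the w_{t+1} equation in the docstring.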