Mirror of https://github.com/ml-explore/mlx.git
Added Adagrad optimizer (#102)
commit c1e1c1443f
parent 68bf1d7867
@@ -221,3 +221,47 @@ class AdamW(Adam):
         return super().apply_single(
             gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
         )
+
+
+class Adagrad(Optimizer):
+    r"""Implementation of the Adagrad optimizer [1].
+
+    Our Adagrad implementation follows the original paper. In detail,
+
+    .. math::
+
+        v_{t+1} &= v_t + g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+    for online learning and stochastic optimization. JMLR 2011.
+    """
+
+    def __init__(self, learning_rate: float, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.eps = eps
+
+        if self.learning_rate < 0.0:
+            raise ValueError(
+                f"Adagrad learning rate should be >=0, {self.learning_rate} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adagrad parameter update and stores :math:`v` in the
+        optimizer state."""
+        lr = self.learning_rate
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = v + mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
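
As a quick sanity check of what this commit adds (not part of the diff itself), the sketch below first replays one Adagrad step by hand and then trains a toy model through the new class. The model, data, and hyperparameters are invented for illustration, and it assumes the mlx.nn / mlx.optimizers API surrounding this commit (nn.Linear, nn.value_and_grad, Optimizer.update):

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import Adagrad

# One Adagrad step by hand (values invented for illustration).
lr, eps = 0.1, 1e-8
w = mx.array([1.0, -2.0])   # parameter
g = mx.array([0.5, 0.25])   # gradient
v = mx.zeros_like(g)        # accumulated squared gradients

v = v + mx.square(g)                 # v_{t+1} = v_t + g_t^2
w = w - lr * g / (mx.sqrt(v) + eps)  # matches apply_single above
print(w)                             # ~[0.9, -2.1]

# Using the new optimizer on a toy model.
model = nn.Linear(4, 1)

def loss_fn(model, X, y):
    return mx.mean(mx.square(model(X) - y))

X = mx.random.normal((32, 4))
y = mx.random.normal((32, 1))

optimizer = Adagrad(learning_rate=1e-2)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for _ in range(10):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)  # applies apply_single to each parameter
    mx.eval(model.parameters(), optimizer.state)

Note that on the very first step Adagrad moves every coordinate by roughly +/- learning_rate, since g / (sqrt(g^2) + eps) is approximately sign(g); the per-coordinate step then shrinks as squared gradients accumulate in v.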