diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index fd5d9c732..ae981c7f3 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -221,3 +221,47 @@ class AdamW(Adam):
         return super().apply_single(
             gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
         )
+
+
+class Adagrad(Optimizer):
+    r"""Implementation of the Adagrad optimizer [1].
+
+    Our Adagrad implementation follows the original paper. In detail,
+
+    .. math::
+
+        v_{t+1} &= v_t + g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+    for online learning and stochastic optimization. JMLR 2011.
+    """
+
+    def __init__(self, learning_rate: float, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.eps = eps
+
+        if self.learning_rate < 0.0:
+            raise ValueError(
+                f"Adagrad learning rate should be >=0, {self.learning_rate} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adagrad parameter update and stores :math:`v` in the
+        optimizer state."""
+        lr = self.learning_rate
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = v + mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
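
For reference, a minimal usage sketch of the new optimizer, assuming the existing Optimizer.update / mlx.nn.value_and_grad training-loop interface applies to Adagrad unchanged; the model, data, and hyperparameters below are hypothetical and only exercise the update rule.

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import Adagrad

# Hypothetical toy model and data, just to produce gradients.
model = nn.Linear(10, 1)
X = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))

def loss_fn(model, X, y):
    return mx.mean(mx.square(model(X) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = Adagrad(learning_rate=1e-2)

for _ in range(10):
    loss, grads = loss_and_grad_fn(model, X, y)
    # update() dispatches to apply_single() for each parameter/gradient
    # pair, accumulating the squared gradients "v" in optimizer.state.
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

Because v only grows, the effective step size for each parameter shrinks over time, which is the intended Adagrad behavior described in the docstring above.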