diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index 4fb48935b..9a77328d3 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -152,3 +152,42 @@ class Adam(Optimizer):
         state["v"] = v

         return parameter - lr * m / (mx.sqrt(v) + eps)
+
+
+class AdamW(Adam):
+    r"""Implementation of the AdamW optimizer [1].
+
+    Following the above convention, this implementation, in contrast with [1],
+    does not apply bias correction to the first and second moments. The weights
+    are updated with a weight decay (λ) term:
+
+    .. math::
+
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
+        w_{t+1} &= w_t - \alpha \left(\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t\right)
+
+    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
+    regularization. ICLR 2019.
+    """
+
+    def __init__(
+        self,
+        learning_rate: float,
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
+        weight_decay: float = 0.01,
+    ):
+        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        self.weight_decay = weight_decay
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the AdamW parameter update by applying decoupled weight
+        decay to the parameter before delegating to the Adam update.
+        """
+
+        return super().apply_single(
+            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
+        )
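
A minimal usage sketch of the new AdamW optimizer in a single training step follows. It is not part of the patch; the linear model, toy data, and squared-error loss are illustrative, and it assumes the `mlx.nn` Module / `nn.value_and_grad` and `Optimizer.update` APIs already present in the package.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Illustrative model and data (not from the diff).
model = nn.Linear(4, 1)
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)

def loss_fn(model, x, y):
    # Plain mean squared error, kept dependency-free.
    return mx.mean((model(x) - y) ** 2)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))

# nn.value_and_grad differentiates the loss with respect to the model's parameters.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)

# AdamW.apply_single multiplies each parameter by (1 - lr * weight_decay)
# before handing it to the inherited Adam update, i.e. decoupled weight decay.
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)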