AdamW implementation (#72)

* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
Joe Barrow 2023-12-08 17:45:34 -05:00 committed by GitHub
parent 5b9be57ac3
commit 69a24e6a1e

@@ -152,3 +152,42 @@ class Adam(Optimizer):
state["v"] = v
return parameter - lr * m / (mx.sqrt(v) + eps)


class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
    with a weight_decay (λ) value:

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdamW parameter update by modifying the parameters
        passed into Adam.
        """
        return super().apply_single(
            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
        )
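
Because `apply_single` hands Adam a parameter already scaled by `(1 - learning_rate * weight_decay)`, the resulting step is `w_t * (1 - α λ) - α m_{t+1} / (sqrt(v_{t+1}) + ε)`, which rearranges to the decoupled update in the docstring; the moment estimates themselves are never touched by the decay, which is what lets this reuse the underlying Adam implementation.

Below is a minimal usage sketch of driving the new optimizer in a training step. It assumes the MLX APIs used elsewhere in this repository (`mlx.core`, `mlx.nn`, `nn.value_and_grad`, `Optimizer.update`); the model and data here are placeholders for illustration.

```python
# Minimal sketch: one AdamW training step (assumes mlx.core / mlx.nn / mlx.optimizers
# as they exist in this repository; the Linear model and random data are illustrative).
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)


def loss_fn(model, x, y):
    # Plain mean-squared error built from core ops.
    return mx.mean((model(x) - y) ** 2)


x = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))

# weight_decay is decoupled from the gradient: each step first scales the
# weights by (1 - learning_rate * weight_decay), then applies the Adam update.
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
```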