Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-27 00:09:17 +08:00)
AdamW implementation (#72)
* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
This commit is contained in:
parent 5b9be57ac3
commit 69a24e6a1e
@@ -152,3 +152,42 @@ class Adam(Optimizer):
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)


class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
    with a weight_decay (λ) value:

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.
    """

    def __init__(
        self,
        learning_rate: float,
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdamW parameter update by modifying the parameters
        passed into Adam.
        """
        return super().apply_single(
            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
        )
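As an aside (not part of the diff above), here is a minimal sketch of how the new class could be exercised, assuming the optimizer is exposed from mlx.optimizers and that apply_single can be called directly with an empty OptimizerState, as the underlying Adam implementation does:

import mlx.core as mx
from mlx.optimizers import AdamW, OptimizerState  # assumed import path

# One AdamW step on a single parameter array.
opt = AdamW(learning_rate=1e-3, weight_decay=0.01)
param = mx.array([1.0, -2.0, 3.0])
grad = mx.array([0.1, 0.1, -0.2])

state = OptimizerState()  # holds the running moments m and v between steps
new_param = opt.apply_single(grad, param, state)
# Decoupled weight decay: `param` is first scaled by
# (1 - learning_rate * weight_decay), then the plain Adam update is applied.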