mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Support bias correction in Adam and AdamW optimizers (#1640)
@@ -395,10 +395,7 @@ class AdaDelta(Optimizer):
 
 
 class Adam(Optimizer):
-    r"""The Adam optimizer [1].
-
-    Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    r"""The Adam optimizer [1]. In detail,
 
     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
     optimization. ICLR 2015.
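(Context, not part of the commit.) The removed sentence referred to the bias correction of Kingma & Ba [1]; with step count t, learning rate \alpha, and parameters w (notation adapted here), the bias-corrected update is

    \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad
    \hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad
    w_{t+1} = w_t - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}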
@@ -416,6 +413,8 @@ class Adam(Optimizer):
             gradient and its square. Default: ``(0.9, 0.999)``
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -423,12 +422,14 @@ class Adam(Optimizer):
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
+        bias_correction: bool = False,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
+        self.bias_correction = bias_correction
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
@@ -441,6 +442,8 @@ class Adam(Optimizer):
         lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps
+        bias_correction = self.bias_correction
+        step = self.step
 
         m = state["m"]
         v = state["v"]
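A small standalone sketch (plain Python, not part of the commit) of how the correction factors driven by `step` behave: they are largest on the first update and decay toward 1, so bias correction mainly changes the early steps.

# Illustration of the 1 / (1 - beta**step) correction factors used below.
b1, b2 = 0.9, 0.999
for step in (1, 10, 100, 1000, 10000):
    c1 = 1.0 / (1.0 - b1**step)   # factor applied to the first moment m
    c2 = 1.0 / (1.0 - b2**step)   # factor applied to the second moment v
    print(f"step={step:6d}  m-factor={c1:.4f}  v-factor={c2:.4f}")
# Both factors approach 1 as step grows; the v factor more slowly,
# since beta2 is closer to 1.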
@@ -449,15 +452,17 @@ class Adam(Optimizer):
         state["m"] = m
         state["v"] = v
 
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        if bias_correction:
+            numerator = lr / (1 - b1**step) * m
+            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            return parameter - numerator / denominator
+        else:
+            return parameter - lr * m / (mx.sqrt(v) + eps)
 
 
 class AdamW(Adam):
-    r"""The AdamW optimizer [1].
-
-    Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+    (:math:`\lambda`) value:
 
     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
     regularization. ICLR 2019.
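The factored form in the new code folds the corrections into the learning rate and the square root instead of materializing m_hat and v_hat; a minimal scalar check (plain Python standard library, not the MLX arrays used above) that it matches the textbook bias-corrected update:

import math

lr, b1, b2, eps, step = 1e-3, 0.9, 0.999, 1e-8, 5
m, v = 0.02, 0.004  # example (hypothetical) moment values

# Textbook form: correct the moments, then apply the usual Adam step.
m_hat = m / (1 - b1**step)
v_hat = v / (1 - b2**step)
textbook = lr * m_hat / (math.sqrt(v_hat) + eps)

# Factored form from the diff: fold the corrections into lr and sqrt(v).
numerator = lr / (1 - b1**step) * m
denominator = math.sqrt(v) / math.sqrt(1 - b2**step) + eps
factored = numerator / denominator

assert math.isclose(textbook, factored, rel_tol=1e-12)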
@@ -477,6 +482,8 @@ class AdamW(Adam):
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
             Default: ``0``.
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -485,8 +492,14 @@ class AdamW(Adam):
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
+        bias_correction: bool = False,
     ):
-        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        super().__init__(
+            learning_rate=learning_rate,
+            betas=betas,
+            eps=eps,
+            bias_correction=bias_correction,
+        )
         self.weight_decay = weight_decay
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
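With this change, opting in to the corrected update would look roughly as follows (a sketch assuming the usual `mlx.optimizers` import path; only the constructor arguments shown here appear in the diff):

import mlx.optimizers as optim

# Default behaviour is unchanged: bias_correction defaults to False.
adam = optim.Adam(learning_rate=1e-3)

# Opt in to the bias-corrected update added by this commit.
adam_bc = optim.Adam(
    learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-8, bias_correction=True
)
adamw_bc = optim.AdamW(learning_rate=1e-3, weight_decay=0.01, bias_correction=True)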