mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Support bias correction in Adam and AdamW optimizers (#1640)
This commit is contained in:
parent
d0b6cb0425
commit
fd3377dd1f
@ -395,10 +395,7 @@ class AdaDelta(Optimizer):
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
r"""The Adam optimizer [1].
|
||||
|
||||
Our Adam implementation follows the original paper and omits the bias
|
||||
correction in the first and second moment estimates. In detail,
|
||||
r"""The Adam optimizer [1]. In detail,
|
||||
|
||||
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
|
||||
optimization. ICLR 2015.
|
||||
@ -416,6 +413,8 @@ class Adam(Optimizer):
|
||||
gradient and its square. Default: ``(0.9, 0.999)``
|
||||
eps (float, optional): The term :math:`\epsilon` added to the
|
||||
denominator to improve numerical stability. Default: ``1e-8``
|
||||
bias_correction (bool, optional): If set to ``True``, bias correction
|
||||
is applied. Default: ``False``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -423,12 +422,14 @@ class Adam(Optimizer):
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
betas: List[float] = [0.9, 0.999],
|
||||
eps: float = 1e-8,
|
||||
bias_correction: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.betas = betas
|
||||
self.eps = eps
|
||||
self.bias_correction = bias_correction
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
@ -441,6 +442,8 @@ class Adam(Optimizer):
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
b1, b2 = self.betas
|
||||
eps = self.eps
|
||||
bias_correction = self.bias_correction
|
||||
step = self.step
|
||||
|
||||
m = state["m"]
|
||||
v = state["v"]
|
||||
@ -449,15 +452,17 @@ class Adam(Optimizer):
|
||||
state["m"] = m
|
||||
state["v"] = v
|
||||
|
||||
return parameter - lr * m / (mx.sqrt(v) + eps)
|
||||
if bias_correction:
|
||||
numerator = lr / (1 - b1**step) * m
|
||||
denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
|
||||
return parameter - numerator / denominator
|
||||
else:
|
||||
return parameter - lr * m / (mx.sqrt(v) + eps)
|
||||
|
||||
|
||||
class AdamW(Adam):
|
||||
r"""The AdamW optimizer [1].
|
||||
|
||||
Following the above convention, in contrast with [1], we do not use bias
|
||||
correction in the first and second moments for AdamW. We update the weights
|
||||
with a weight_decay (:math:`\lambda`) value:
|
||||
r"""The AdamW optimizer [1]. We update the weights with a weight_decay
|
||||
(:math:`\lambda`) value:
|
||||
|
||||
[1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
|
||||
regularization. ICLR 2019.
|
||||
@ -477,6 +482,8 @@ class AdamW(Adam):
|
||||
denominator to improve numerical stability. Default: ``1e-8``
|
||||
weight_decay (float, optional): The weight decay :math:`\lambda`.
|
||||
Default: ``0``.
|
||||
bias_correction (bool, optional): If set to ``True``, bias correction
|
||||
is applied. Default: ``False``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -485,8 +492,14 @@ class AdamW(Adam):
|
||||
betas: List[float] = [0.9, 0.999],
|
||||
eps: float = 1e-8,
|
||||
weight_decay: float = 0.01,
|
||||
bias_correction: bool = False,
|
||||
):
|
||||
super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
bias_correction=bias_correction,
|
||||
)
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
|
@ -10,8 +10,17 @@ import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
import mlx.utils
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
has_torch = True
|
||||
except ImportError as e:
|
||||
has_torch = False
|
||||
|
||||
|
||||
def get_all_optimizers():
|
||||
classes = dict()
|
||||
@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_adamw_matches_pytorch(self):
|
||||
mx.random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
model = nn.Linear(3, 1)
|
||||
init_weight = np.array(model.weight.tolist())
|
||||
init_bias = np.array(model.bias.tolist())
|
||||
|
||||
def loss_fn(model, x, y):
|
||||
pred = model(x)
|
||||
return nn.losses.mse_loss(pred, y)
|
||||
|
||||
x = np.random.rand(3, 3)
|
||||
y = np.random.rand(3, 1)
|
||||
|
||||
optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Equivalent torch code
|
||||
torch_model = torch.nn.Linear(3, 1)
|
||||
|
||||
# copy over the parameters
|
||||
torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
|
||||
torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
|
||||
|
||||
torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
|
||||
torch_optimizer.zero_grad()
|
||||
pred = torch_model(torch.tensor(x, dtype=torch.float32))
|
||||
loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
|
||||
loss.backward()
|
||||
torch_optimizer.step()
|
||||
|
||||
for name, param in torch_model.named_parameters():
|
||||
mlx_grad = np.array(grads[name])
|
||||
torch_grad = param.grad.detach().numpy()
|
||||
self.assertTrue(np.allclose(torch_grad, mlx_grad))
|
||||
|
||||
for name, param in torch_model.named_parameters():
|
||||
mlx_param = np.array(model[name])
|
||||
torch_param = param.data.detach().numpy()
|
||||
self.assertTrue(np.allclose(torch_param, mlx_param))
|
||||
|
||||
def test_lion(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
|
Loading…
Reference in New Issue
Block a user