Mirror of https://github.com/ml-explore/mlx.git
Support bias correction in Adam and AdamW optimizers (#1640)
parent d0b6cb0425
commit fd3377dd1f
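For context, a minimal usage sketch of the option this commit adds. It only uses calls that appear in the diff and test below (mlx.optimizers.Adam/AdamW, nn.value_and_grad, optimizer.update); the random inputs via mx.random.normal are illustrative and not part of the change.

    # Hypothetical usage of the new bias_correction flag (the default stays False).
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt

    model = nn.Linear(3, 1)
    optimizer = opt.Adam(learning_rate=1e-3, bias_correction=True)
    # AdamW forwards the same flag to Adam:
    # optimizer = opt.AdamW(learning_rate=1e-3, weight_decay=0.01, bias_correction=True)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    x, y = mx.random.normal((8, 3)), mx.random.normal((8, 1))
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)  # bias-corrected update is applied when the flag is True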
@@ -395,10 +395,7 @@ class AdaDelta(Optimizer):
 
 
 class Adam(Optimizer):
-    r"""The Adam optimizer [1].
-
-    Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    r"""The Adam optimizer [1]. In detail,
 
     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
     optimization. ICLR 2015.
@@ -416,6 +413,8 @@ class Adam(Optimizer):
             gradient and its square. Default: ``(0.9, 0.999)``
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -423,12 +422,14 @@ class Adam(Optimizer):
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
+        bias_correction: bool = False,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
+        self.bias_correction = bias_correction
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
@@ -441,6 +442,8 @@ class Adam(Optimizer):
         lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps
+        bias_correction = self.bias_correction
+        step = self.step
 
         m = state["m"]
         v = state["v"]
@@ -449,15 +452,17 @@ class Adam(Optimizer):
         state["m"] = m
         state["v"] = v
 
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        if bias_correction:
+            numerator = lr / (1 - b1**step) * m
+            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            return parameter - numerator / denominator
+        else:
+            return parameter - lr * m / (mx.sqrt(v) + eps)
 
 
 class AdamW(Adam):
-    r"""The AdamW optimizer [1].
-
-    Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+    (:math:`\lambda`) value:
 
     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
     regularization. ICLR 2019.
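For reference, the new branch implements the standard bias-corrected Adam step; folding the correction factors into the learning rate and the denominator gives exactly the numerator/denominator expression above (a sketch, using \eta for the learning rate and t for the step count):

    \begin{aligned}
    m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
    v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
    \hat{m}_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \\
    w_{t+1} &= w_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
            = w_t - \frac{\frac{\eta}{1-\beta_1^t}\, m_t}{\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon}
    \end{aligned}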
@@ -477,6 +482,8 @@ class AdamW(Adam):
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
             Default: ``0``.
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -485,8 +492,14 @@ class AdamW(Adam):
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
+        bias_correction: bool = False,
     ):
-        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        super().__init__(
+            learning_rate=learning_rate,
+            betas=betas,
+            eps=eps,
+            bias_correction=bias_correction,
+        )
         self.weight_decay = weight_decay
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
@@ -10,8 +10,17 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
+import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 def get_all_optimizers():
     classes = dict()
@@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
             )
         )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_adamw_matches_pytorch(self):
+        mx.random.seed(0)
+        np.random.seed(0)
+
+        model = nn.Linear(3, 1)
+        init_weight = np.array(model.weight.tolist())
+        init_bias = np.array(model.bias.tolist())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return nn.losses.mse_loss(pred, y)
+
+        x = np.random.rand(3, 3)
+        y = np.random.rand(3, 1)
+
+        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
+        optimizer.update(model, grads)
+
+        # Equivalent torch code
+        torch_model = torch.nn.Linear(3, 1)
+
+        # copy over the parameters
+        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
+        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
+
+        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
+        torch_optimizer.zero_grad()
+        pred = torch_model(torch.tensor(x, dtype=torch.float32))
+        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
+        loss.backward()
+        torch_optimizer.step()
+
+        for name, param in torch_model.named_parameters():
+            mlx_grad = np.array(grads[name])
+            torch_grad = param.grad.detach().numpy()
+            self.assertTrue(np.allclose(torch_grad, mlx_grad))
+
+        for name, param in torch_model.named_parameters():
+            mlx_param = np.array(model[name])
+            torch_param = param.data.detach().numpy()
+            self.assertTrue(np.allclose(torch_param, mlx_param))
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],
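As a standalone sanity check (not part of the commit), the sketch below verifies with plain NumPy that the folded expression in the `if bias_correction:` branch matches the textbook lr * m_hat / (sqrt(v_hat) + eps) form after one step; the toy gradient and hyperparameters are arbitrary.

    import numpy as np

    rng = np.random.default_rng(0)
    g = rng.standard_normal(5)      # toy gradient
    lr, b1, b2, eps, step = 3e-4, 0.9, 0.999, 1e-8, 1

    # Moment estimates after one step, starting from zero state.
    m = (1 - b1) * g
    v = (1 - b2) * g**2

    # Folded form, as written in the new branch above.
    numerator = lr / (1 - b1**step) * m
    denominator = np.sqrt(v) / np.sqrt(1 - b2**step) + eps
    folded = numerator / denominator

    # Textbook form with explicit bias-corrected moments.
    m_hat = m / (1 - b1**step)
    v_hat = v / (1 - b2**step)
    textbook = lr * m_hat / (np.sqrt(v_hat) + eps)

    assert np.allclose(folded, textbook)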