diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 1b37bcc26..3d40dd0d1 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -395,10 +395,7 @@ class AdaDelta(Optimizer):
 
 
 class Adam(Optimizer):
-    r"""The Adam optimizer [1].
-
-    Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    r"""The Adam optimizer [1]. In detail,
 
     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
     optimization. ICLR 2015.
@@ -416,6 +413,8 @@ class Adam(Optimizer):
             gradient and its square. Default: ``(0.9, 0.999)``
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -423,12 +422,14 @@ class Adam(Optimizer):
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
+        bias_correction: bool = False,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
+        self.bias_correction = bias_correction
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
@@ -441,6 +442,8 @@ class Adam(Optimizer):
         lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps
+        bias_correction = self.bias_correction
+        step = self.step
 
         m = state["m"]
         v = state["v"]
@@ -449,15 +452,17 @@ class Adam(Optimizer):
         state["m"] = m
         state["v"] = v
 
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        if bias_correction:
+            numerator = lr / (1 - b1**step) * m
+            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            return parameter - numerator / denominator
+        else:
+            return parameter - lr * m / (mx.sqrt(v) + eps)
 
 
 class AdamW(Adam):
-    r"""The AdamW optimizer [1].
-
-    Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+    (:math:`\lambda`) value:
 
     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
     regularization. ICLR 2019.
@@ -477,6 +482,8 @@ class AdamW(Adam):
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
             Default: ``0``.
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -485,8 +492,14 @@ class AdamW(Adam):
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
+        bias_correction: bool = False,
     ):
-        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        super().__init__(
+            learning_rate=learning_rate,
+            betas=betas,
+            eps=eps,
+            bias_correction=bias_correction,
+        )
         self.weight_decay = weight_decay
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index d817c10e9..cf3a2b4fa 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -10,8 +10,17 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
+import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 def get_all_optimizers():
     classes = dict()
@@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
             )
         )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_adamw_matches_pytorch(self):
+        mx.random.seed(0)
+        np.random.seed(0)
+
+        model = nn.Linear(3, 1)
+        init_weight = np.array(model.weight.tolist())
+        init_bias = np.array(model.bias.tolist())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return nn.losses.mse_loss(pred, y)
+
+        x = np.random.rand(3, 3)
+        y = np.random.rand(3, 1)
+
+        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
+        optimizer.update(model, grads)
+
+        # Equivalent torch code
+        torch_model = torch.nn.Linear(3, 1)
+
+        # copy over the parameters
+        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
+        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
+
+        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
+        torch_optimizer.zero_grad()
+        pred = torch_model(torch.tensor(x, dtype=torch.float32))
+        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
+        loss.backward()
+        torch_optimizer.step()
+
+        for name, param in torch_model.named_parameters():
+            mlx_grad = np.array(grads[name])
+            torch_grad = param.grad.detach().numpy()
+            self.assertTrue(np.allclose(torch_grad, mlx_grad))
+
+        for name, param in torch_model.named_parameters():
+            mlx_param = np.array(model[name])
+            torch_param = param.data.detach().numpy()
+            self.assertTrue(np.allclose(torch_param, mlx_param))
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],
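As a sanity check on the bias-corrected update added above, here is a minimal NumPy sketch of a single Adam step using the same algebra as the new `if bias_correction:` branch. The `adam_step` helper and its inputs are illustrative only, not part of this PR:

```python
import numpy as np

def adam_step(param, grad, m, v, step, lr=3e-4, b1=0.9, b2=0.999,
              eps=1e-8, bias_correction=True):
    # Running averages of the gradient and its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    if bias_correction:
        # Same algebra as the diff: lr / (1 - b1**t) * m divided by
        # sqrt(v) / sqrt(1 - b2**t) + eps.
        numerator = lr / (1 - b1**step) * m
        denominator = np.sqrt(v) / np.sqrt(1 - b2**step) + eps
        return param - numerator / denominator, m, v
    # Uncorrected update, matching the existing default path.
    return param - lr * m / (np.sqrt(v) + eps), m, v

# One step starting from zero moments, as in Adam.init_single.
param = np.array([1.0, -2.0, 3.0])
grad = np.array([0.1, -0.2, 0.3])
param, m, v = adam_step(param, grad, np.zeros(3), np.zeros(3), step=1)
print(param)  # each entry moves by roughly lr opposite the gradient sign
```

At `step = 1` the correction factors `1 - b1**step` and `1 - b2**step` are at their smallest, which is where the uncorrected update deviates most from PyTorch's bias-corrected `AdamW`; that is why a single-step comparison in `test_adamw_matches_pytorch` is enough to exercise the new branch.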