Support bias correction in Adam and AdamW optimizers (#1640)

This commit is contained in:
mt_caret
2024-12-07 05:13:34 +09:00
committed by GitHub
parent d0b6cb0425
commit fd3377dd1f
2 changed files with 78 additions and 11 deletions

View File

@@ -10,8 +10,17 @@ import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError as e:
has_torch = False
def get_all_optimizers():
classes = dict()
@@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
)
)
@unittest.skipIf(not has_torch, "requires Torch")
def test_adamw_matches_pytorch(self):
mx.random.seed(0)
np.random.seed(0)
model = nn.Linear(3, 1)
init_weight = np.array(model.weight.tolist())
init_bias = np.array(model.bias.tolist())
def loss_fn(model, x, y):
pred = model(x)
return nn.losses.mse_loss(pred, y)
x = np.random.rand(3, 3)
y = np.random.rand(3, 1)
optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
optimizer.update(model, grads)
# Equivalent torch code
torch_model = torch.nn.Linear(3, 1)
# copy over the parameters
torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
torch_optimizer.zero_grad()
pred = torch_model(torch.tensor(x, dtype=torch.float32))
loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
loss.backward()
torch_optimizer.step()
for name, param in torch_model.named_parameters():
mlx_grad = np.array(grads[name])
torch_grad = param.grad.detach().numpy()
self.assertTrue(np.allclose(torch_grad, mlx_grad))
for name, param in torch_model.named_parameters():
mlx_param = np.array(model[name])
torch_param = param.data.detach().numpy()
self.assertTrue(np.allclose(torch_param, mlx_param))
def test_lion(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],