Support bias correction in Adam and AdamW optimizers (#1640)
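For context, Adam initializes its moment estimates at zero, which biases them toward zero for the first steps; bias correction rescales them by 1 / (1 - beta^t), as in Kingma & Ba (2015). Below is a minimal NumPy sketch of the bias-corrected step that a bias_correction=True flag toggles (illustrative only, not the MLX implementation; the function name adam_step and the hyperparameter defaults are assumptions):

    import numpy as np

    def adam_step(param, grad, m, v, t, lr=3e-4, b1=0.9, b2=0.999,
                  eps=1e-8, bias_correction=True):
        # Exponential moving averages of the gradient and its square
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad ** 2
        if bias_correction:
            # Undo the zero-initialization bias at (1-indexed) step t
            m_hat = m / (1 - b1 ** t)
            v_hat = v / (1 - b2 ** t)
        else:
            m_hat, v_hat = m, v
        # AdamW would additionally apply decoupled weight decay here:
        # param = param - lr * weight_decay * param
        param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
        return param, m, v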
@@ -10,8 +10,17 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
+
 def get_all_optimizers():
     classes = dict()
@@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
             )
         )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_adamw_matches_pytorch(self):
+        mx.random.seed(0)
+        np.random.seed(0)
+
+        model = nn.Linear(3, 1)
+        init_weight = np.array(model.weight.tolist())
+        init_bias = np.array(model.bias.tolist())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return nn.losses.mse_loss(pred, y)
+
+        x = np.random.rand(3, 3)
+        y = np.random.rand(3, 1)
+
+        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
+        optimizer.update(model, grads)
+
+        # Equivalent torch code
+        torch_model = torch.nn.Linear(3, 1)
+
+        # copy over the parameters
+        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
+        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
+
+        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
+        torch_optimizer.zero_grad()
+        pred = torch_model(torch.tensor(x, dtype=torch.float32))
+        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
+        loss.backward()
+        torch_optimizer.step()
+
+        for name, param in torch_model.named_parameters():
+            mlx_grad = np.array(grads[name])
+            torch_grad = param.grad.detach().numpy()
+            self.assertTrue(np.allclose(torch_grad, mlx_grad))
+
+        for name, param in torch_model.named_parameters():
+            mlx_param = np.array(model[name])
+            torch_param = param.data.detach().numpy()
+            self.assertTrue(np.allclose(torch_param, mlx_param))
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],
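A note on the comparison in this test: torch.optim.AdamW always applies bias correction, so the one-step parity check should only pass when the MLX optimizer is constructed with bias_correction=True. Without correction (and ignoring eps), the first update differs from the corrected one by a factor of roughly (1 - beta1) / sqrt(1 - beta2), about 3.16 at the default betas, which np.allclose would flag.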