Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

	Support bias correction in Adam and AdamW optimizers (#1640)
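For context: Adam's moment estimates m and v are initialized to zero, so early in training they are biased toward zero; bias correction divides them by (1 - beta^t) to undo this. Below is a minimal NumPy sketch of the idea, illustrative only — `adam_step` and its defaults are hypothetical helpers, not MLX's implementation:

    import numpy as np

    def adam_step(param, grad, m, v, t, lr=3e-4, b1=0.9, b2=0.999,
                  eps=1e-8, bias_correction=True):
        # Exponential moving averages of the gradient and its square.
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad * grad
        if bias_correction:
            # Undo the zero-initialization bias; t is the 1-based step count.
            m_hat = m / (1 - b1**t)
            v_hat = v / (1 - b2**t)
        else:
            m_hat, v_hat = m, v
        param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
        return param, m, v

PyTorch's Adam/AdamW always applies this correction, which is why the test below compares an MLX optimizer constructed with bias_correction=True against torch.optim.AdamW.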
@@ -10,8 +10,17 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
+import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 def get_all_optimizers():
     classes = dict()
@@ -186,6 +195,51 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                 )
             )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_adamw_matches_pytorch(self):
+        mx.random.seed(0)
+        np.random.seed(0)
+
+        model = nn.Linear(3, 1)
+        init_weight = np.array(model.weight.tolist())
+        init_bias = np.array(model.bias.tolist())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return nn.losses.mse_loss(pred, y)
+
+        x = np.random.rand(3, 3)
+        y = np.random.rand(3, 1)
+
+        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
+        optimizer.update(model, grads)
+
+        # Equivalent torch code
+        torch_model = torch.nn.Linear(3, 1)
+
+        # copy over the parameters
+        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
+        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
+
+        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
+        torch_optimizer.zero_grad()
+        pred = torch_model(torch.tensor(x, dtype=torch.float32))
+        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
+        loss.backward()
+        torch_optimizer.step()
+
+        for name, param in torch_model.named_parameters():
+            mlx_grad = np.array(grads[name])
+            torch_grad = param.grad.detach().numpy()
+            self.assertTrue(np.allclose(torch_grad, mlx_grad))
+
+        for name, param in torch_model.named_parameters():
+            mlx_param = np.array(model[name])
+            torch_param = param.data.detach().numpy()
+            self.assertTrue(np.allclose(torch_param, mlx_param))
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],
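For reference, the new flag is used exactly as in the test above. A short usage sketch against the API shown in the diff (the random sample data via mx.random.uniform is an assumption for illustration; the rest mirrors the test):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt

    model = nn.Linear(3, 1)
    # bias_correction=True enables the new PyTorch-matching behavior.
    optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    x = mx.random.uniform(shape=(8, 3))
    y = mx.random.uniform(shape=(8, 1))
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)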
Author: mt_caret